{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.00818634033203125,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 2,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "92716b5e30794cbc99a22348045026cf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyc/miniconda3/envs/ke2torch23cu121/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/home/lyc/miniconda3/envs/ke2torch23cu121/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/home/lyc/miniconda3/envs/ke2torch23cu121/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "/home/lyc/miniconda3/envs/ke2torch23cu121/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Restrict this process to GPU 0. Set before torch/transformers are imported so the\n",
    "# restriction is already in place before any library can initialize CUDA.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "# Local checkpoint directory (tokenizer files and model weights are both read from here).\n",
    "model_path = '/share/huggingface/Llama-2-7b-ms'\n",
    "\n",
    "# A tokenizer has no weights to place on a device, so `device_map` is only passed to the model.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "\n",
    "# device_map='auto' delegates weight placement to the visible device(s);\n",
    "# the weights are loaded in bfloat16 rather than the default float32.\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto', torch_dtype=torch.bfloat16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CausalLMOutputWithPast(loss=None, logits=tensor([[[-12.5625,  -7.1250,  -0.6289,  ...,  -6.6562,  -7.9375,  -7.3125],\n",
       "         [-10.4375,  -6.7500,  -0.7852,  ...,  -7.0312,  -6.6562,  -8.2500],\n",
       "         [ -9.1250,  -6.1250,   2.4219,  ...,  -2.4844,  -2.9688,  -3.9375],\n",
       "         ...,\n",
       "         [-11.7500, -12.1250,   0.3535,  ...,  -7.4062,  -6.6875,  -7.3125],\n",
       "         [-10.1250, -10.2500,   2.0469,  ...,  -5.4062,  -5.2812,  -4.0312],\n",
       "         [ -3.1875,  -2.1875,  11.7500,  ...,   0.2109,  -0.8125,  -0.6953]]],\n",
       "       device='cuda:0', grad_fn=<ToCopyBackward0>), past_key_values=((tensor([[[[-4.5117e-01, -1.6968e-02,  5.1270e-02,  ...,  5.2246e-02,\n",
       "           -2.3804e-02, -1.2598e-01],\n",
       "          [-8.0469e-01,  4.6484e-01,  4.5898e-02,  ...,  4.0430e-01,\n",
       "           -1.2402e-01,  6.0547e-01],\n",
       "          [ 1.0859e+00, -3.9453e-01, -5.7031e-01,  ..., -3.2227e-01,\n",
       "            1.6113e-01, -3.6328e-01],\n",
       "          ...,\n",
       "          [-3.8867e-01,  3.6328e-01, -1.6504e-01,  ..., -2.1094e-01,\n",
       "            2.3145e-01, -1.6895e-01],\n",
       "          [-5.2490e-02,  7.4219e-02,  2.9297e-01,  ..., -9.6436e-03,\n",
       "            1.0059e-01,  1.1292e-02],\n",
       "          [-7.6172e-02, -3.1738e-02, -4.7607e-03,  ...,  1.3477e-01,\n",
       "            8.5449e-02,  1.7480e-01]],\n",
       "\n",
       "         [[ 1.3516e+00,  1.0781e+00, -4.2969e-01,  ...,  4.7656e-01,\n",
       "           -3.0469e-01,  4.7852e-01],\n",
       "          [ 7.4219e-01,  7.0312e-01, -3.3789e-01,  ..., -1.9629e-01,\n",
       "            3.4180e-01, -2.2656e-01],\n",
       "          [-1.5312e+00, -1.4062e+00,  3.3789e-01,  ..., -2.6562e-01,\n",
       "            2.9492e-01, -2.6953e-01],\n",
       "          ...,\n",
       "          [ 9.6094e-01,  1.9531e-01, -3.0469e-01,  ..., -4.6094e-01,\n",
       "            4.6631e-02, -3.4180e-01],\n",
       "          [ 7.4609e-01,  6.1719e-01, -3.7354e-02,  ..., -4.8633e-01,\n",
       "            1.4355e-01, -3.5742e-01],\n",
       "          [ 4.8828e-01,  4.6484e-01, -6.0156e-01,  ...,  7.3438e-01,\n",
       "           -1.6504e-01,  6.2109e-01]],\n",
       "\n",
       "         [[ 2.0142e-03, -2.5781e-01, -4.1602e-01,  ..., -1.2207e-01,\n",
       "            4.9805e-01,  7.6172e-01],\n",
       "          [-3.3789e-01, -1.0059e-01, -4.1016e-01,  ..., -4.7070e-01,\n",
       "           -5.1562e-01, -4.6680e-01],\n",
       "          [ 3.0273e-02, -2.4609e-01,  4.0527e-02,  ...,  1.4609e+00,\n",
       "            1.4922e+00,  1.3438e+00],\n",
       "          ...,\n",
       "          [-4.5117e-01, -2.8516e-01, -3.6133e-01,  ...,  1.7656e+00,\n",
       "            1.8047e+00,  1.6484e+00],\n",
       "          [-5.3906e-01, -2.6953e-01, -6.6895e-02,  ...,  1.6562e+00,\n",
       "            1.6719e+00,  1.5547e+00],\n",
       "          [-1.5625e-01, -3.3203e-01,  1.4258e-01,  ..., -9.8828e-01,\n",
       "           -1.1406e+00, -1.1016e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.4658e-02, -4.8523e-03,  3.9795e-02,  ...,  4.8633e-01,\n",
       "            9.6484e-01, -4.0234e-01],\n",
       "          [ 9.6680e-02,  7.3730e-02,  2.7954e-02,  ...,  1.8750e-01,\n",
       "            5.8984e-01, -1.1133e-01],\n",
       "          [ 1.1719e-02,  1.8047e+00,  4.1406e-01,  ..., -7.9297e-01,\n",
       "           -1.8203e+00, -1.5469e+00],\n",
       "          ...,\n",
       "          [ 1.5625e-02, -1.0703e+00,  7.4609e-01,  ..., -6.0938e-01,\n",
       "           -1.0391e+00, -1.3359e+00],\n",
       "          [-1.0156e+00, -7.8516e-01, -1.6309e-01,  ..., -6.3281e-01,\n",
       "           -8.9062e-01, -1.1172e+00],\n",
       "          [-6.1719e-01, -2.6562e-01, -3.2617e-01,  ..., -3.2422e-01,\n",
       "           -2.5781e-01, -6.9141e-01]],\n",
       "\n",
       "         [[ 1.7285e-01, -5.3906e-01, -3.1641e-01,  ..., -2.7344e-01,\n",
       "            1.3672e-01,  1.4062e-01],\n",
       "          [-1.0254e-01,  4.3555e-01, -8.5547e-01,  ...,  2.5391e-01,\n",
       "           -1.2256e-01, -1.1816e-01],\n",
       "          [ 2.7344e-01,  4.9805e-01, -1.0469e+00,  ..., -6.6797e-01,\n",
       "            3.0859e-01,  3.0859e-01],\n",
       "          ...,\n",
       "          [-5.2344e-01, -3.2812e-01,  6.8750e-01,  ..., -3.5547e-01,\n",
       "            9.2773e-02,  9.1797e-02],\n",
       "          [-4.7656e-01, -7.8125e-01,  9.1797e-01,  ..., -2.8906e-01,\n",
       "            8.1055e-02,  8.5449e-02],\n",
       "          [-4.3945e-02, -2.3730e-01, -7.0703e-01,  ...,  2.1680e-01,\n",
       "           -3.6865e-02, -3.6133e-02]],\n",
       "\n",
       "         [[-4.5117e-01,  1.2266e+00,  2.9144e-03,  ...,  7.1484e-01,\n",
       "            2.1582e-01,  2.7734e-01],\n",
       "          [-3.9648e-01, -9.0625e-01, -9.7266e-01,  ...,  1.2656e+00,\n",
       "           -6.4062e-01,  3.6328e-01],\n",
       "          [-1.8945e-01,  6.1523e-02,  5.5078e-01,  ..., -2.3750e+00,\n",
       "            1.1406e+00, -3.2031e-01],\n",
       "          ...,\n",
       "          [-2.6367e-02,  4.1748e-02, -1.4160e-02,  ..., -1.5703e+00,\n",
       "            9.8438e-01, -2.4902e-01],\n",
       "          [-1.6406e-01, -5.2246e-02, -3.3691e-02,  ..., -8.3594e-01,\n",
       "            4.7070e-01, -1.2207e-01],\n",
       "          [ 6.0547e-02, -4.1602e-01, -6.6406e-01,  ...,  8.2812e-01,\n",
       "           -4.0430e-01, -8.5449e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-6.0425e-03, -6.3782e-03,  5.5847e-03,  ...,  1.3809e-03,\n",
       "            2.0264e-02, -5.5237e-03],\n",
       "          [ 3.8719e-04, -8.7891e-03, -3.2806e-03,  ...,  5.3406e-03,\n",
       "            1.2131e-03, -6.1646e-03],\n",
       "          [-5.3406e-03,  3.0823e-03, -2.9297e-03,  ...,  5.1022e-05,\n",
       "           -2.4719e-03,  6.9275e-03],\n",
       "          ...,\n",
       "          [ 3.1586e-03,  1.7166e-03,  6.1035e-04,  ...,  6.3782e-03,\n",
       "            2.3193e-03, -9.2506e-05],\n",
       "          [ 4.0894e-03, -2.8534e-03,  3.5400e-03,  ..., -2.3746e-04,\n",
       "            3.7079e-03, -8.7280e-03],\n",
       "          [-2.2736e-03,  6.0425e-03,  7.2632e-03,  ...,  9.8419e-04,\n",
       "            8.9111e-03, -7.1716e-03]],\n",
       "\n",
       "         [[ 2.0599e-03,  4.0283e-03, -9.7046e-03,  ..., -4.7607e-03,\n",
       "            6.9275e-03, -1.9043e-02],\n",
       "          [ 8.1539e-05, -9.0790e-04,  2.3956e-03,  ...,  9.7046e-03,\n",
       "            2.1057e-03, -3.3417e-03],\n",
       "          [ 4.5204e-04,  2.8229e-04, -8.2397e-04,  ...,  2.4109e-03,\n",
       "            1.6327e-03,  9.3079e-04],\n",
       "          ...,\n",
       "          [-5.4321e-03,  1.8082e-03, -6.9275e-03,  ...,  1.4420e-03,\n",
       "           -4.1504e-03,  2.6131e-04],\n",
       "          [-4.4556e-03,  2.1973e-03,  2.9907e-03,  ..., -1.8768e-03,\n",
       "           -6.1035e-03, -1.8997e-03],\n",
       "          [ 2.1973e-03, -4.6387e-03,  5.6458e-04,  ...,  2.0905e-03,\n",
       "           -3.4637e-03,  8.3008e-03]],\n",
       "\n",
       "         [[-3.8605e-03, -6.4392e-03,  1.4221e-02,  ...,  5.0354e-03,\n",
       "           -1.4648e-02,  6.9885e-03],\n",
       "          [ 7.9956e-03, -6.7139e-03, -6.6757e-04,  ...,  5.0354e-03,\n",
       "            2.9449e-03, -3.7231e-03],\n",
       "          [-1.1253e-04,  2.4109e-03,  4.4632e-04,  ...,  6.0272e-04,\n",
       "            5.4321e-03, -4.6997e-03],\n",
       "          ...,\n",
       "          [-2.6131e-04,  1.1444e-03, -3.2663e-05,  ...,  1.2112e-04,\n",
       "            5.2490e-03, -1.1597e-03],\n",
       "          [ 1.3428e-03, -3.3112e-03, -8.0109e-04,  ..., -7.9727e-04,\n",
       "           -4.8447e-04,  6.1035e-03],\n",
       "          [ 1.9836e-03,  1.1749e-03, -2.5635e-03,  ..., -1.4114e-03,\n",
       "           -7.9346e-03,  3.8528e-04]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-5.9509e-04, -2.0752e-03, -1.5747e-02,  ...,  2.0752e-02,\n",
       "            9.4604e-03, -5.1880e-03],\n",
       "          [-1.0376e-02,  4.0283e-02, -6.9336e-02,  ..., -9.7656e-02,\n",
       "           -3.1494e-02, -2.0752e-02],\n",
       "          [-1.1536e-02, -1.7822e-02, -1.2024e-02,  ...,  6.0120e-03,\n",
       "           -9.0820e-02,  2.4414e-02],\n",
       "          ...,\n",
       "          [-8.8501e-03, -2.4414e-02, -3.0884e-02,  ...,  2.1606e-02,\n",
       "           -4.6387e-03, -2.1729e-02],\n",
       "          [ 2.2827e-02, -1.9287e-02,  3.4912e-02,  ...,  3.1738e-02,\n",
       "           -3.1738e-02,  3.9795e-02],\n",
       "          [ 1.0071e-02, -4.4922e-02, -7.8735e-03,  ..., -1.7456e-02,\n",
       "            5.6641e-02,  1.8433e-02]],\n",
       "\n",
       "         [[ 4.3106e-04,  1.5503e-02, -1.0254e-02,  ..., -8.6060e-03,\n",
       "           -4.5166e-02, -2.5024e-02],\n",
       "          [-6.5918e-03,  2.0386e-02, -6.9885e-03,  ..., -6.9580e-03,\n",
       "            7.9632e-05, -4.5776e-03],\n",
       "          [ 2.1820e-03,  2.2583e-03,  3.4027e-03,  ..., -2.4128e-04,\n",
       "           -5.9509e-03,  4.4250e-04],\n",
       "          ...,\n",
       "          [ 1.7853e-03,  3.9062e-03, -7.5150e-04,  ..., -2.8610e-04,\n",
       "            5.4626e-03,  1.2207e-03],\n",
       "          [ 3.6316e-03,  4.6997e-03,  1.2207e-03,  ...,  5.6458e-03,\n",
       "            4.0894e-03, -1.0376e-02],\n",
       "          [ 8.3008e-03, -1.8477e-06, -2.0294e-03,  ...,  3.8147e-03,\n",
       "           -5.3101e-03,  1.6479e-03]],\n",
       "\n",
       "         [[-3.9978e-03, -3.9368e-03,  4.6730e-04,  ...,  1.4572e-03,\n",
       "           -6.7139e-04, -1.4420e-03],\n",
       "          [ 2.9144e-03,  1.0498e-02, -9.3384e-03,  ...,  1.4191e-03,\n",
       "           -9.7656e-03,  4.1809e-03],\n",
       "          [-1.6327e-03,  1.3809e-03,  2.3499e-03,  ..., -6.0654e-04,\n",
       "           -1.9226e-03,  1.8234e-03],\n",
       "          ...,\n",
       "          [ 2.9449e-03, -1.0529e-03, -6.4697e-03,  ..., -3.0136e-04,\n",
       "            5.5847e-03,  3.4790e-03],\n",
       "          [ 3.3569e-03, -1.3794e-02,  8.6670e-03,  ..., -4.6730e-04,\n",
       "            1.4191e-03,  2.9449e-03],\n",
       "          [-7.2632e-03,  2.6093e-03, -1.9989e-03,  ...,  1.4709e-02,\n",
       "           -5.5237e-03, -1.4648e-03]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-0.3184,  0.7617,  0.5352,  ...,  0.1494,  0.1118, -0.2021],\n",
       "          [-1.7656,  0.1016, -0.6992,  ..., -0.8164, -0.5859,  0.2930],\n",
       "          [-0.7148,  0.0466, -1.2734,  ..., -0.5312, -0.3438, -0.0903],\n",
       "          ...,\n",
       "          [ 1.4375,  0.0913, -0.1094,  ..., -0.9375, -0.9219,  0.2139],\n",
       "          [-0.5703,  0.1055,  0.6055,  ..., -0.8320, -1.0938,  0.6172],\n",
       "          [-1.1953,  0.3926,  1.2266,  ..., -0.4551, -1.2344,  0.8438]],\n",
       "\n",
       "         [[-0.1016, -0.3770,  0.0405,  ...,  0.6875,  0.6758,  0.6680],\n",
       "          [ 0.1514, -0.1484, -1.0156,  ..., -1.1328, -0.4121, -0.3633],\n",
       "          [ 0.1270,  0.3047, -0.1865,  ..., -0.2217,  0.2178,  0.3594],\n",
       "          ...,\n",
       "          [-0.5312,  0.0752,  0.7539,  ..., -0.8516,  0.0728, -0.2168],\n",
       "          [ 0.1445, -0.7422,  1.0000,  ..., -0.8984, -0.3438, -0.5273],\n",
       "          [ 1.3438, -0.8086, -0.2617,  ..., -0.7734, -0.4062, -0.5664]],\n",
       "\n",
       "         [[-0.0544,  0.4590,  0.4805,  ...,  1.1875,  0.7227,  0.1001],\n",
       "          [-0.1074, -0.2793, -0.2402,  ...,  0.8477,  0.3984,  0.2715],\n",
       "          [-0.0762, -0.2637,  0.0835,  ...,  1.2344,  0.6836,  0.4141],\n",
       "          ...,\n",
       "          [ 0.3574,  0.3633,  0.2168,  ...,  1.1016,  0.4883,  0.5664],\n",
       "          [ 0.5195,  0.1377,  0.0713,  ...,  1.0781,  0.5352,  0.4629],\n",
       "          [ 0.3340, -0.0474, -0.0713,  ...,  1.0000,  0.3906,  0.5117]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-0.7070,  0.0168,  0.0400,  ...,  0.9219, -0.3340, -0.0630],\n",
       "          [-0.1504, -0.1182, -0.1670,  ..., -0.0684, -0.1631,  0.0532],\n",
       "          [ 0.6133, -0.1455,  0.1533,  ...,  0.5820, -0.2773,  0.2314],\n",
       "          ...,\n",
       "          [-0.6016,  0.0640, -0.1777,  ..., -0.2539, -0.0356,  0.1396],\n",
       "          [-0.0059, -0.0669, -0.4492,  ..., -0.0022, -0.1309,  0.1650],\n",
       "          [ 0.3105, -0.1611, -0.3008,  ..., -0.0208, -0.0027,  0.0928]],\n",
       "\n",
       "         [[ 0.2217,  0.2490, -0.1631,  ..., -0.4531,  0.7031, -0.0097],\n",
       "          [-0.0352, -0.4199,  0.0034,  ..., -0.1973, -0.0830, -0.3555],\n",
       "          [-0.2402, -0.1484,  0.1348,  ..., -0.6094,  0.6016,  0.0085],\n",
       "          ...,\n",
       "          [ 0.4043, -0.0337, -0.0308,  ...,  0.1001, -0.2441, -0.0776],\n",
       "          [ 0.5742, -0.1943,  0.4062,  ..., -0.1050, -0.3008,  0.3672],\n",
       "          [-0.1914, -0.4844,  0.7383,  ...,  0.1279, -0.5117,  0.1001]],\n",
       "\n",
       "         [[-0.0182,  0.0591,  0.0898,  ...,  0.2012, -0.2031,  0.1094],\n",
       "          [-0.4707,  0.4551,  0.4414,  ...,  1.0781, -1.1250,  0.7344],\n",
       "          [ 0.1367,  0.0889,  0.2598,  ...,  0.7383, -0.7891,  0.4434],\n",
       "          ...,\n",
       "          [-0.3945, -0.3984, -0.2168,  ...,  1.2188, -1.2422,  0.8555],\n",
       "          [-0.8945, -0.2656, -0.3184,  ...,  1.3125, -1.3516,  0.9492],\n",
       "          [-0.7891,  0.0869, -0.0615,  ...,  1.4062, -1.4141,  1.0078]]]],\n",
       "       device='cuda:0', dtype=torch.bfloat16, grad_fn=<AddBackward0>), tensor([[[[ 8.9355e-02, -4.1199e-03, -3.0518e-02,  ...,  1.6846e-02,\n",
       "           -1.8311e-02,  6.5430e-02],\n",
       "          [-1.0498e-02,  1.1084e-01,  4.8584e-02,  ...,  3.7598e-02,\n",
       "           -1.4465e-02, -6.3965e-02],\n",
       "          [ 8.1543e-02,  6.9580e-03,  1.1658e-02,  ...,  1.1902e-02,\n",
       "            3.9482e-04, -1.5869e-02],\n",
       "          ...,\n",
       "          [-1.1426e-01,  1.3574e-01, -6.8970e-03,  ...,  2.9755e-03,\n",
       "           -4.7852e-02,  3.4180e-02],\n",
       "          [-1.3672e-01,  6.0547e-02,  2.9541e-02,  ..., -1.3916e-02,\n",
       "            1.0376e-02, -2.7710e-02],\n",
       "          [-4.1992e-02, -7.1106e-03, -2.5024e-02,  ..., -3.2715e-02,\n",
       "            3.6133e-02,  7.5684e-03]],\n",
       "\n",
       "         [[-9.8419e-04,  7.4005e-04, -9.0942e-03,  ...,  2.6550e-03,\n",
       "            8.7280e-03,  7.8735e-03],\n",
       "          [ 1.0254e-02, -1.5991e-02, -8.3618e-03,  ...,  1.5015e-02,\n",
       "           -2.7466e-03, -1.0742e-02],\n",
       "          [ 9.1553e-03,  4.3335e-03,  3.2501e-03,  ...,  7.1106e-03,\n",
       "            5.0659e-03, -7.2327e-03],\n",
       "          ...,\n",
       "          [ 5.6458e-03, -9.4604e-03, -7.2479e-04,  ...,  5.6458e-04,\n",
       "            1.0193e-02,  6.6833e-03],\n",
       "          [ 4.7607e-03, -1.1292e-02,  2.0409e-04,  ..., -6.2561e-03,\n",
       "            1.0376e-02,  6.6833e-03],\n",
       "          [-2.7466e-02, -1.4526e-02,  8.7280e-03,  ...,  1.8188e-02,\n",
       "            1.1536e-02, -2.1820e-03]],\n",
       "\n",
       "         [[-3.2471e-02, -1.3428e-02, -2.3804e-02,  ..., -8.9355e-02,\n",
       "           -9.5215e-02,  6.8359e-02],\n",
       "          [ 4.4189e-02,  1.1658e-02,  8.3984e-02,  ...,  1.2598e-01,\n",
       "            7.2754e-02, -1.1328e-01],\n",
       "          [-3.0396e-02,  1.4648e-02, -3.5889e-02,  ..., -5.0049e-02,\n",
       "           -4.3457e-02,  6.6406e-02],\n",
       "          ...,\n",
       "          [ 9.3750e-02,  5.4016e-03,  2.5146e-02,  ..., -7.9102e-02,\n",
       "            1.9531e-01, -5.1270e-02],\n",
       "          [ 3.9307e-02,  3.1128e-02,  1.5747e-02,  ...,  9.4727e-02,\n",
       "           -3.6377e-02, -2.9102e-01],\n",
       "          [-1.1597e-02,  1.5747e-02, -4.2725e-03,  ..., -1.5039e-01,\n",
       "           -7.5684e-02,  2.6367e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.9053e-02,  1.9287e-02, -1.3542e-04,  ..., -3.5400e-02,\n",
       "           -1.2695e-02, -2.0142e-02],\n",
       "          [-1.7212e-02, -4.9133e-03,  1.9897e-02,  ...,  9.0942e-03,\n",
       "           -1.7700e-02,  1.4099e-02],\n",
       "          [-2.2278e-03, -1.1444e-03,  4.5776e-03,  ...,  2.5749e-04,\n",
       "           -1.6479e-03, -1.2390e-02],\n",
       "          ...,\n",
       "          [ 2.2888e-03,  1.4191e-03,  2.4658e-02,  ..., -1.6602e-02,\n",
       "           -1.5869e-02,  4.7913e-03],\n",
       "          [-1.0315e-02, -1.0864e-02,  3.2471e-02,  ..., -2.0142e-03,\n",
       "           -1.3062e-02,  1.6113e-02],\n",
       "          [ 1.3916e-02, -8.2397e-03,  1.9836e-03,  ..., -1.0315e-02,\n",
       "            5.6641e-02, -5.0659e-03]],\n",
       "\n",
       "         [[-1.5234e-01, -2.0142e-02, -1.3379e-01,  ..., -2.8076e-03,\n",
       "            5.5542e-03, -8.9111e-03],\n",
       "          [-8.3008e-02, -6.7139e-03,  2.5195e-01,  ..., -1.4404e-02,\n",
       "            1.3123e-02,  2.9297e-03],\n",
       "          [-1.5137e-02,  2.3926e-02, -2.0020e-02,  ...,  2.2316e-04,\n",
       "           -5.8289e-03, -2.2339e-02],\n",
       "          ...,\n",
       "          [-3.3203e-02, -1.0645e-01, -4.0283e-02,  ...,  6.1646e-03,\n",
       "            4.0588e-03, -6.8970e-03],\n",
       "          [-2.5757e-02, -1.8463e-03,  1.8262e-01,  ..., -1.8311e-02,\n",
       "            3.7354e-02,  1.1719e-02],\n",
       "          [-2.1191e-01,  2.3315e-02,  6.9336e-02,  ..., -8.7280e-03,\n",
       "           -4.2114e-03,  2.2583e-03]],\n",
       "\n",
       "         [[-3.9368e-03,  1.2024e-02,  1.1658e-02,  ..., -2.2430e-03,\n",
       "            8.4229e-03,  4.3640e-03],\n",
       "          [ 2.9907e-02,  1.1230e-02, -6.7749e-03,  ...,  2.4719e-03,\n",
       "           -1.0742e-02,  2.3682e-02],\n",
       "          [-5.7678e-03, -4.9744e-03,  6.0120e-03,  ...,  1.8501e-04,\n",
       "           -3.3112e-03, -4.1199e-03],\n",
       "          ...,\n",
       "          [-2.0294e-03,  1.4709e-02,  1.1292e-02,  ..., -7.2327e-03,\n",
       "            5.0659e-03,  8.3618e-03],\n",
       "          [-7.4158e-03, -6.3171e-03,  1.2817e-02,  ...,  1.0010e-02,\n",
       "            4.2419e-03,  3.9062e-03],\n",
       "          [ 3.0670e-03,  2.4414e-02, -7.3242e-03,  ..., -7.7209e-03,\n",
       "           -2.4414e-02,  3.0518e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 7.9346e-03,  2.7710e-02,  1.5564e-02,  ..., -1.8750e-01,\n",
       "           -1.1768e-01,  1.4746e-01],\n",
       "          [-1.1250e+00, -1.2988e-01,  4.8828e-01,  ...,  6.9922e-01,\n",
       "            3.5742e-01, -3.3594e-01],\n",
       "          [-1.2266e+00,  0.0000e+00,  3.5352e-01,  ...,  2.4414e-01,\n",
       "           -7.9102e-02, -6.6016e-01],\n",
       "          ...,\n",
       "          [ 1.6562e+00,  3.4766e-01,  1.2402e-01,  ..., -5.7031e-01,\n",
       "            2.8125e-01, -2.1484e-01],\n",
       "          [ 1.1875e+00, -1.0010e-01, -1.5234e-01,  ..., -8.5938e-01,\n",
       "            1.2031e+00, -8.3203e-01],\n",
       "          [ 0.0000e+00, -4.4531e-01, -5.9375e-01,  ...,  6.6797e-01,\n",
       "            8.3594e-01, -1.0625e+00]],\n",
       "\n",
       "         [[-1.9653e-02,  2.4048e-02,  5.3223e-02,  ...,  2.4902e-02,\n",
       "            3.3398e-01, -3.4570e-01],\n",
       "          [ 3.2031e-01,  0.0000e+00,  9.2773e-02,  ..., -3.1250e-01,\n",
       "           -1.6016e+00,  1.6641e+00],\n",
       "          [-1.8164e-01, -1.5039e-01, -2.8125e-01,  ...,  3.2422e-01,\n",
       "           -2.1562e+00,  2.3438e+00],\n",
       "          ...,\n",
       "          [ 2.1875e-01,  6.1279e-02,  3.4375e-01,  ..., -8.3594e-01,\n",
       "           -2.1406e+00,  2.3906e+00],\n",
       "          [-9.2285e-02, -6.8359e-02, -1.4844e-01,  ..., -3.0078e-01,\n",
       "           -1.9062e+00,  1.8672e+00],\n",
       "          [ 4.4922e-01,  7.9297e-01, -5.2344e-01,  ..., -3.8281e-01,\n",
       "           -1.6953e+00,  1.9453e+00]],\n",
       "\n",
       "         [[-3.1128e-02,  7.4707e-02,  2.5024e-02,  ..., -8.6914e-02,\n",
       "            3.4180e-01, -3.5156e-01],\n",
       "          [ 3.0469e+00, -2.1719e+00,  4.0234e-01,  ...,  1.6953e+00,\n",
       "           -1.7734e+00,  1.7656e+00],\n",
       "          [ 1.1094e+00, -1.2969e+00,  1.3477e-01,  ...,  1.2734e+00,\n",
       "           -1.9297e+00,  2.6875e+00],\n",
       "          ...,\n",
       "          [-1.9219e+00,  1.0625e+00, -8.9844e-01,  ...,  3.2031e+00,\n",
       "           -2.6562e+00,  2.4844e+00],\n",
       "          [ 1.5625e-01,  1.2734e+00,  9.7656e-04,  ...,  1.2969e+00,\n",
       "           -9.6875e-01,  2.0312e+00],\n",
       "          [ 1.9062e+00,  6.5430e-02, -7.9590e-02,  ..., -1.2207e-01,\n",
       "           -3.0469e+00,  2.0469e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 3.3936e-02,  4.3640e-03,  2.0020e-02,  ...,  3.5889e-02,\n",
       "            3.7598e-02,  9.5215e-02],\n",
       "          [ 3.7500e-01,  1.2500e-01,  8.9355e-02,  ...,  2.4219e-01,\n",
       "           -5.7422e-01, -2.6758e-01],\n",
       "          [ 1.5430e-01, -4.6289e-01,  2.6758e-01,  ...,  3.8477e-01,\n",
       "           -4.7266e-01,  8.6914e-02],\n",
       "          ...,\n",
       "          [-4.3359e-01, -1.2207e-02,  1.2256e-01,  ..., -2.3145e-01,\n",
       "           -1.7773e-01, -1.6113e-02],\n",
       "          [-5.8203e-01,  1.0742e-01, -9.1309e-02,  ..., -1.1377e-01,\n",
       "            1.5625e-01, -8.8281e-01],\n",
       "          [-1.1328e-01, -1.8555e-02,  8.1787e-03,  ...,  3.6523e-01,\n",
       "           -1.5918e-01, -7.2266e-01]],\n",
       "\n",
       "         [[-7.9346e-03,  3.4424e-02,  9.8267e-03,  ..., -7.2754e-02,\n",
       "           -6.4941e-02,  3.4180e-03],\n",
       "          [-1.2969e+00,  6.2109e-01, -1.6797e-01,  ...,  3.3203e-01,\n",
       "           -4.6680e-01, -5.9375e-01],\n",
       "          [-1.4062e+00,  1.4453e+00,  7.8125e-02,  ...,  2.4414e-01,\n",
       "            7.4219e-01, -1.0781e+00],\n",
       "          ...,\n",
       "          [ 2.0156e+00, -6.2500e-01,  3.2422e-01,  ..., -2.8125e-01,\n",
       "           -3.1250e-01, -4.5703e-01],\n",
       "          [ 1.3438e+00, -2.2344e+00, -1.0938e+00,  ...,  2.5625e+00,\n",
       "            1.7031e+00,  1.9141e+00],\n",
       "          [-5.6641e-01, -1.8984e+00, -7.1484e-01,  ...,  7.3047e-01,\n",
       "            1.6328e+00, -9.7656e-02]],\n",
       "\n",
       "         [[-4.9316e-02, -2.2339e-02, -1.8311e-02,  ...,  1.0791e-01,\n",
       "            4.7266e-01,  8.6426e-02],\n",
       "          [ 1.0781e+00, -4.8438e-01, -1.0859e+00,  ..., -4.1016e-01,\n",
       "           -2.1094e+00, -2.5586e-01],\n",
       "          [-5.3906e-01, -1.1797e+00, -5.3125e-01,  ...,  3.4375e-01,\n",
       "           -2.5469e+00,  8.0566e-02],\n",
       "          ...,\n",
       "          [ 5.1562e-01,  3.9648e-01,  1.0391e+00,  ..., -1.3984e+00,\n",
       "           -2.7344e+00,  1.0156e+00],\n",
       "          [ 1.6250e+00,  6.2891e-01,  8.7500e-01,  ..., -1.6016e+00,\n",
       "           -2.5625e+00, -6.5234e-01],\n",
       "          [ 1.4531e+00,  4.7656e-01,  3.9062e-01,  ..., -7.6172e-01,\n",
       "           -2.2500e+00, -1.0859e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 3.8300e-03, -9.2316e-04,  4.5776e-03,  ...,  1.5163e-04,\n",
       "           -9.9182e-05, -1.3504e-03],\n",
       "          [-2.2949e-02,  1.5527e-01,  2.0410e-01,  ...,  7.1289e-02,\n",
       "            3.8330e-02, -5.1514e-02],\n",
       "          [-9.6680e-02, -1.1169e-02, -1.5527e-01,  ..., -5.1025e-02,\n",
       "            1.8066e-02, -2.3438e-02],\n",
       "          ...,\n",
       "          [ 4.0039e-02,  8.0566e-02,  1.3855e-02,  ...,  1.1621e-01,\n",
       "           -7.8125e-02, -2.1680e-01],\n",
       "          [-9.6191e-02, -6.1719e-01, -1.2207e-01,  ...,  3.2715e-02,\n",
       "           -1.1670e-01,  9.7168e-02],\n",
       "          [-1.6602e-02, -1.4062e-01,  2.3828e-01,  ...,  1.9043e-01,\n",
       "            3.2812e-01, -3.8672e-01]],\n",
       "\n",
       "         [[ 2.7832e-02,  7.0190e-04,  3.7384e-03,  ..., -9.2316e-04,\n",
       "           -3.1433e-03,  1.6708e-03],\n",
       "          [-9.8047e-01, -1.5137e-01, -1.6602e-01,  ..., -1.9409e-02,\n",
       "           -1.6699e-01,  1.0059e-01],\n",
       "          [-3.5742e-01, -9.7168e-02, -2.5391e-01,  ..., -2.4414e-01,\n",
       "           -6.5430e-02, -4.1748e-02],\n",
       "          ...,\n",
       "          [-2.5977e-01,  6.2012e-02, -6.0059e-02,  ..., -5.5664e-02,\n",
       "            1.0400e-01, -3.5400e-02],\n",
       "          [-7.6562e-01, -2.6953e-01, -2.6172e-01,  ...,  9.9609e-02,\n",
       "           -4.9316e-02, -1.4062e-01],\n",
       "          [-6.8359e-01,  4.6387e-02, -1.6406e-01,  ..., -2.8906e-01,\n",
       "           -6.5918e-02,  9.2285e-02]],\n",
       "\n",
       "         [[-6.3419e-05, -3.0060e-03,  5.8899e-03,  ...,  1.7776e-03,\n",
       "           -1.1597e-03, -1.2283e-03],\n",
       "          [ 1.5723e-01,  2.0703e-01, -9.7656e-02,  ..., -3.0273e-02,\n",
       "            3.2227e-02, -7.8613e-02],\n",
       "          [-7.3730e-02,  1.2109e-01,  1.3672e-01,  ...,  4.1199e-03,\n",
       "            1.1523e-01,  1.0596e-01],\n",
       "          ...,\n",
       "          [-2.1289e-01,  1.4343e-02,  9.1797e-02,  ..., -8.2031e-02,\n",
       "            2.4658e-02,  3.9551e-02],\n",
       "          [ 1.2061e-01,  4.0283e-02,  9.6680e-02,  ...,  6.0791e-02,\n",
       "            1.1963e-01,  2.4658e-02],\n",
       "          [ 2.4707e-01, -4.9316e-02, -2.7930e-01,  ..., -6.9336e-02,\n",
       "            9.5703e-02, -1.7090e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 4.4556e-03,  2.0752e-02, -4.8218e-03,  ..., -5.1880e-03,\n",
       "            1.8597e-04,  4.6692e-03],\n",
       "          [ 1.0742e-01, -2.3145e-01,  2.3633e-01,  ..., -1.0864e-02,\n",
       "           -1.5137e-01,  5.5908e-02],\n",
       "          [-6.5430e-02, -5.1270e-03,  1.6504e-01,  ..., -4.9561e-02,\n",
       "           -3.1836e-01,  6.2012e-02],\n",
       "          ...,\n",
       "          [-2.4902e-02,  1.8848e-01,  3.6865e-02,  ..., -1.9043e-01,\n",
       "           -2.1973e-02, -4.0283e-02],\n",
       "          [-3.8086e-02,  1.5430e-01,  1.3477e-01,  ..., -4.5898e-01,\n",
       "           -4.9219e-01, -6.8359e-02],\n",
       "          [ 1.6504e-01, -6.9922e-01, -1.4746e-01,  ...,  7.2266e-01,\n",
       "           -1.2354e-01,  1.0059e-01]],\n",
       "\n",
       "         [[ 4.0894e-03, -9.6436e-03,  2.0599e-03,  ...,  4.8828e-04,\n",
       "           -1.5182e-03, -2.3193e-03],\n",
       "          [-2.9297e-02,  2.0264e-02, -9.0942e-03,  ...,  1.3281e-01,\n",
       "           -1.9409e-02,  2.0605e-01],\n",
       "          [ 3.0151e-02, -3.8281e-01,  1.2283e-03,  ..., -2.0312e-01,\n",
       "           -2.5781e-01, -1.8311e-02],\n",
       "          ...,\n",
       "          [-5.3223e-02, -4.3213e-02,  8.1787e-03,  ..., -1.0300e-03,\n",
       "           -8.6426e-02, -1.0132e-02],\n",
       "          [ 1.4941e-01, -4.3359e-01,  2.4109e-03,  ..., -3.3203e-01,\n",
       "           -1.6895e-01, -4.2480e-02],\n",
       "          [ 1.6406e-01, -3.0664e-01,  1.5918e-01,  ...,  4.4189e-02,\n",
       "           -2.7148e-01, -1.9629e-01]],\n",
       "\n",
       "         [[ 2.5177e-03, -3.4637e-03, -8.0585e-05,  ...,  1.8311e-02,\n",
       "           -6.5613e-04, -2.1820e-03],\n",
       "          [ 3.7500e-01,  1.3086e-01, -7.0801e-02,  ..., -3.8477e-01,\n",
       "           -2.9297e-01,  2.8711e-01],\n",
       "          [-4.5898e-02,  9.1309e-02, -2.6172e-01,  ..., -5.4932e-02,\n",
       "           -1.3379e-01,  4.5654e-02],\n",
       "          ...,\n",
       "          [ 1.1230e-01,  1.2891e-01,  3.3984e-01,  ..., -2.7539e-01,\n",
       "           -6.4941e-02,  2.1680e-01],\n",
       "          [-2.5879e-02,  1.3086e-01, -2.9297e-01,  ..., -9.0332e-02,\n",
       "           -5.2734e-02, -9.9121e-02],\n",
       "          [-1.3770e-01,  9.5215e-03,  3.3936e-02,  ..., -9.0820e-02,\n",
       "            1.3770e-01, -1.1035e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 1.0681e-02, -1.8845e-03,  1.2451e-02,  ...,  5.5859e-01,\n",
       "            5.5176e-02, -5.5469e-01],\n",
       "          [-8.5156e-01,  4.8438e-01,  1.5820e-01,  ..., -2.2344e+00,\n",
       "           -1.1094e+00,  2.4688e+00],\n",
       "          [ 2.4316e-01, -4.5508e-01,  1.8164e-01,  ..., -2.3594e+00,\n",
       "           -5.9375e-01,  2.4688e+00],\n",
       "          ...,\n",
       "          [-4.1797e-01,  7.1875e-01,  4.3164e-01,  ..., -2.7812e+00,\n",
       "           -9.3750e-01,  2.8906e+00],\n",
       "          [-4.7070e-01,  1.4844e-01,  2.9297e-01,  ..., -3.2969e+00,\n",
       "            3.1406e+00,  3.4531e+00],\n",
       "          [-1.6016e+00,  5.7422e-01,  4.4141e-01,  ..., -2.9062e+00,\n",
       "           -2.2188e+00,  3.0156e+00]],\n",
       "\n",
       "         [[ 2.2736e-03,  3.4180e-02, -1.7822e-02,  ..., -4.2383e-01,\n",
       "           -4.5703e-01, -3.6523e-01],\n",
       "          [ 2.8125e-01, -1.2988e-01, -9.4727e-02,  ...,  8.9062e-01,\n",
       "            1.2578e+00,  1.2031e+00],\n",
       "          [-5.4199e-02, -7.5195e-02,  3.8574e-02,  ...,  6.4844e-01,\n",
       "            1.5938e+00,  4.8828e-01],\n",
       "          ...,\n",
       "          [ 2.7539e-01,  1.8555e-01, -3.2227e-02,  ...,  9.3750e-01,\n",
       "            1.5078e+00,  7.8125e-01],\n",
       "          [ 2.0703e-01,  5.6641e-02, -7.8125e-02,  ...,  8.0859e-01,\n",
       "            7.8516e-01,  7.9688e-01],\n",
       "          [ 1.0645e-01,  1.0742e-01, -6.3477e-02,  ...,  1.7031e+00,\n",
       "            1.4766e+00,  8.1250e-01]],\n",
       "\n",
       "         [[ 1.8799e-02, -1.5564e-02,  3.7193e-05,  ...,  1.1641e+00,\n",
       "           -1.2061e-01,  1.9824e-01],\n",
       "          [-7.1094e-01,  7.8125e-01,  6.3281e-01,  ..., -4.7812e+00,\n",
       "            8.3203e-01, -1.3516e+00],\n",
       "          [ 3.7109e-01,  3.4180e-01, -4.7461e-01,  ..., -5.4688e+00,\n",
       "            9.1406e-01, -6.9141e-01],\n",
       "          ...,\n",
       "          [ 1.6602e-02, -2.5586e-01,  4.4727e-01,  ..., -5.5625e+00,\n",
       "           -1.4844e+00, -1.4355e-01],\n",
       "          [-8.2812e-01,  4.1992e-01,  1.0156e-01,  ..., -6.1250e+00,\n",
       "           -4.8438e-01,  1.7480e-01],\n",
       "          [-1.0156e+00,  7.9688e-01,  2.0898e-01,  ..., -5.0938e+00,\n",
       "           -2.0312e-01, -1.6172e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 6.2012e-02, -1.1475e-02,  7.3853e-03,  ...,  8.2031e-01,\n",
       "           -8.0078e-02,  7.0312e-02],\n",
       "          [ 1.1328e+00,  5.4688e-01,  4.7461e-01,  ..., -2.3125e+00,\n",
       "           -1.6797e-01,  1.4062e+00],\n",
       "          [ 4.1797e-01,  7.6562e-01,  9.9121e-02,  ..., -3.2969e+00,\n",
       "           -2.5586e-01, -2.6367e-01],\n",
       "          ...,\n",
       "          [-4.2578e-01, -7.5000e-01, -2.7466e-02,  ..., -3.2031e+00,\n",
       "           -4.5703e-01,  7.8906e-01],\n",
       "          [ 2.4023e-01,  8.7891e-02, -2.0605e-01,  ..., -3.7969e+00,\n",
       "           -3.7891e-01, -7.3828e-01],\n",
       "          [ 5.6250e-01, -2.9492e-01, -2.8906e-01,  ..., -2.9531e+00,\n",
       "           -3.0859e-01, -4.9316e-02]],\n",
       "\n",
       "         [[ 2.5757e-02,  3.3691e-02, -3.5400e-02,  ...,  8.9844e-02,\n",
       "           -9.0820e-02, -1.3977e-02],\n",
       "          [-5.9375e-01, -4.4922e-01,  9.0234e-01,  ..., -1.2188e+00,\n",
       "           -7.6562e-01, -6.5234e-01],\n",
       "          [-1.8945e-01, -7.4219e-02,  9.6191e-02,  ..., -1.2500e+00,\n",
       "            1.6016e+00, -6.0547e-01],\n",
       "          ...,\n",
       "          [ 4.5898e-01,  5.5469e-01, -1.7700e-02,  ..., -1.7031e+00,\n",
       "            2.8320e-01,  2.0410e-01],\n",
       "          [-4.3750e-01,  1.4453e-01, -9.5703e-02,  ..., -1.4062e-01,\n",
       "            2.0469e+00,  6.9922e-01],\n",
       "          [-4.4141e-01, -8.0469e-01, -1.9141e-01,  ..., -1.2500e+00,\n",
       "            1.1641e+00,  1.7480e-01]],\n",
       "\n",
       "         [[ 2.6245e-02,  1.7456e-02, -1.0071e-03,  ...,  7.0312e-01,\n",
       "           -6.9922e-01, -5.9766e-01],\n",
       "          [-2.2852e-01,  2.2070e-01,  4.3945e-01,  ..., -4.7188e+00,\n",
       "            1.8516e+00,  9.8438e-01],\n",
       "          [-4.1016e-02,  4.3555e-01,  1.6113e-01,  ..., -3.7031e+00,\n",
       "            1.5469e+00,  1.9766e+00],\n",
       "          ...,\n",
       "          [-1.1719e-01, -7.0312e-02, -1.0742e-01,  ..., -5.0312e+00,\n",
       "            7.1562e+00,  4.4062e+00],\n",
       "          [-4.2578e-01, -2.5391e-01, -1.3672e-01,  ..., -4.6875e+00,\n",
       "            5.8125e+00,  2.2969e+00],\n",
       "          [-6.3672e-01, -3.6914e-01,  2.0020e-01,  ..., -5.8594e-01,\n",
       "            1.3359e+00,  5.5312e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 2.5940e-03, -5.1270e-03,  8.6975e-04,  ..., -4.1199e-03,\n",
       "           -1.6861e-03,  4.6387e-03],\n",
       "          [ 3.4180e-01,  6.8848e-02,  5.4688e-01,  ...,  4.3164e-01,\n",
       "           -1.9922e-01, -1.0986e-02],\n",
       "          [ 2.5000e-01,  1.2988e-01, -1.8457e-01,  ..., -2.7734e-01,\n",
       "           -5.5469e-01,  4.1016e-02],\n",
       "          ...,\n",
       "          [ 1.2109e-01, -7.1106e-03, -9.8633e-02,  ..., -3.0151e-02,\n",
       "            4.3555e-01,  2.0215e-01],\n",
       "          [ 1.0205e-01,  1.7480e-01, -1.6895e-01,  ..., -2.0605e-01,\n",
       "            2.5195e-01, -3.0396e-02],\n",
       "          [-5.7068e-03,  1.6016e-01,  5.4688e-01,  ...,  1.7188e-01,\n",
       "           -1.9922e-01,  2.0996e-01]],\n",
       "\n",
       "         [[ 7.9346e-04, -5.9128e-04, -3.6430e-04,  ..., -6.9275e-03,\n",
       "            2.8076e-03, -2.1973e-03],\n",
       "          [ 2.2656e-01,  2.5977e-01,  2.4219e-01,  ...,  1.9238e-01,\n",
       "            3.3398e-01, -4.6631e-02],\n",
       "          [-4.7607e-03,  2.4414e-01, -4.1260e-02,  ...,  2.2461e-01,\n",
       "            2.6562e-01, -2.4902e-01],\n",
       "          ...,\n",
       "          [ 6.1719e-01,  2.6953e-01,  2.8931e-02,  ...,  2.7734e-01,\n",
       "            7.8125e-02,  1.4551e-01],\n",
       "          [ 1.6602e-01,  1.7188e-01, -1.7285e-01,  ..., -4.9744e-03,\n",
       "            4.6143e-02,  6.6895e-02],\n",
       "          [-2.7344e-01, -1.4941e-01, -1.6113e-01,  ..., -9.6191e-02,\n",
       "           -4.3457e-02, -2.9492e-01]],\n",
       "\n",
       "         [[-1.0452e-03,  5.2795e-03, -1.0132e-02,  ...,  2.7466e-03,\n",
       "           -5.4626e-03,  1.9684e-03],\n",
       "          [ 1.8066e-01,  4.6289e-01,  3.0859e-01,  ..., -2.2461e-01,\n",
       "           -2.9297e-01, -2.0508e-01],\n",
       "          [ 8.8379e-02, -1.5820e-01, -4.6680e-01,  ..., -3.5156e-02,\n",
       "           -7.2327e-03,  3.9258e-01],\n",
       "          ...,\n",
       "          [ 2.5586e-01, -1.4404e-02, -4.2969e-01,  ..., -1.4160e-01,\n",
       "            3.2617e-01, -6.8359e-02],\n",
       "          [-3.0469e-01,  2.1582e-01, -7.4707e-02,  ...,  2.3926e-01,\n",
       "            1.2500e-01, -3.7109e-01],\n",
       "          [ 8.7891e-02,  1.7383e-01,  2.4414e-01,  ...,  9.5703e-02,\n",
       "            3.6719e-01,  1.3086e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 6.2561e-03,  4.2725e-03,  5.7983e-03,  ..., -2.5330e-03,\n",
       "           -4.6082e-03, -1.4019e-04],\n",
       "          [ 1.2354e-01,  1.5234e-01, -3.4766e-01,  ...,  3.4668e-02,\n",
       "           -1.0803e-02, -5.8594e-01],\n",
       "          [-1.0449e-01, -2.2168e-01,  1.9141e-01,  ..., -2.7148e-01,\n",
       "            1.2988e-01,  2.2266e-01],\n",
       "          ...,\n",
       "          [-7.3242e-02,  6.9922e-01, -3.5352e-01,  ...,  9.0942e-03,\n",
       "            4.7852e-01,  3.8086e-01],\n",
       "          [-2.5977e-01,  1.9922e-01,  5.6250e-01,  ...,  1.2402e-01,\n",
       "            7.0312e-02, -1.1816e-01],\n",
       "          [-2.8906e-01, -1.2598e-01,  3.5742e-01,  ...,  9.2773e-02,\n",
       "           -1.8750e-01, -5.1172e-01]],\n",
       "\n",
       "         [[-4.5776e-03, -3.7842e-03,  6.9885e-03,  ...,  1.1902e-03,\n",
       "           -5.0049e-03,  9.6130e-04],\n",
       "          [-5.9570e-02,  6.3672e-01, -5.0781e-01,  ...,  1.8066e-01,\n",
       "           -1.8359e-01, -1.5625e-02],\n",
       "          [-4.4189e-02,  3.6719e-01, -2.0117e-01,  ..., -4.2188e-01,\n",
       "            2.0703e-01,  4.4434e-02],\n",
       "          ...,\n",
       "          [ 3.9978e-03, -3.9062e-01,  4.0039e-01,  ...,  7.0312e-01,\n",
       "           -7.3730e-02, -3.7842e-02],\n",
       "          [-1.7944e-02, -3.4570e-01, -2.7539e-01,  ...,  3.5889e-02,\n",
       "           -5.3711e-02, -2.0898e-01],\n",
       "          [-3.8867e-01, -6.1279e-02, -9.5215e-02,  ...,  2.4805e-01,\n",
       "            3.2227e-01,  3.8672e-01]],\n",
       "\n",
       "         [[-2.7313e-03,  3.6621e-03,  4.6997e-03,  ..., -2.9755e-03,\n",
       "            3.4180e-03,  4.1199e-03],\n",
       "          [-1.4221e-02,  1.8066e-02, -3.8281e-01,  ...,  7.2327e-03,\n",
       "           -1.8164e-01, -1.4832e-02],\n",
       "          [-2.4902e-01, -5.2979e-02, -2.2754e-01,  ...,  5.1514e-02,\n",
       "            8.5938e-02, -6.3965e-02],\n",
       "          ...,\n",
       "          [-1.5137e-01,  2.1240e-02, -2.8229e-03,  ...,  1.8188e-02,\n",
       "            3.9551e-02,  1.2598e-01],\n",
       "          [-1.0156e-01,  2.3926e-01,  3.4570e-01,  ...,  4.4556e-03,\n",
       "            1.1328e-01, -7.0312e-02],\n",
       "          [-1.2988e-01,  1.8066e-01,  2.8198e-02,  ..., -2.0020e-02,\n",
       "           -9.1309e-02,  7.9102e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-1.3000e-02,  4.8828e-02,  2.6978e-02,  ...,  4.2969e-02,\n",
       "            9.7168e-02,  1.9043e-01],\n",
       "          [ 1.9219e+00,  3.1836e-01, -7.8125e-01,  ...,  9.5312e-01,\n",
       "           -9.1016e-01, -6.3672e-01],\n",
       "          [ 2.5781e-01,  4.1406e-01, -1.6699e-01,  ..., -6.8750e-01,\n",
       "            5.6152e-02,  1.2109e+00],\n",
       "          ...,\n",
       "          [-2.0703e-01, -5.2734e-02,  5.6250e-01,  ..., -1.9922e-01,\n",
       "            9.2578e-01,  1.5918e-01],\n",
       "          [ 9.2578e-01, -9.6875e-01, -2.5000e-01,  ...,  4.5508e-01,\n",
       "           -5.8984e-01, -2.2969e+00],\n",
       "          [ 1.3438e+00, -4.0430e-01,  5.6641e-02,  ..., -4.6631e-02,\n",
       "           -1.8359e-01, -4.8828e-01]],\n",
       "\n",
       "         [[ 1.1673e-03, -1.4709e-02, -8.1177e-03,  ...,  7.6172e-01,\n",
       "            2.1289e-01,  4.8633e-01],\n",
       "          [ 2.7148e-01,  3.7109e-02,  3.6719e-01,  ..., -1.4844e+00,\n",
       "           -1.0234e+00, -4.7266e-01],\n",
       "          [ 7.8613e-02,  1.3574e-01,  6.4844e-01,  ..., -1.0781e+00,\n",
       "           -1.7422e+00, -8.5938e-01],\n",
       "          ...,\n",
       "          [ 7.6660e-02,  4.7852e-02, -5.7422e-01,  ..., -1.2344e+00,\n",
       "           -6.1328e-01,  3.4961e-01],\n",
       "          [ 1.5430e-01,  4.6094e-01, -1.4062e-01,  ..., -2.2812e+00,\n",
       "            1.0625e+00, -3.2043e-03],\n",
       "          [ 1.8262e-01,  1.1328e-01,  7.5391e-01,  ..., -2.4375e+00,\n",
       "            1.7109e+00,  4.0234e-01]],\n",
       "\n",
       "         [[-8.9111e-03, -4.7852e-02,  5.7373e-02,  ..., -5.0049e-02,\n",
       "           -1.5234e-01,  1.5039e-01],\n",
       "          [-2.5391e-01, -1.6602e-01,  3.3203e-01,  ...,  4.4727e-01,\n",
       "            4.9609e-01,  1.1953e+00],\n",
       "          [ 1.2085e-02, -2.6855e-03,  1.7188e-01,  ...,  6.9580e-03,\n",
       "            7.4609e-01,  4.6484e-01],\n",
       "          ...,\n",
       "          [-9.4238e-02,  3.9307e-02,  1.0010e-02,  ...,  1.6895e-01,\n",
       "            1.1426e-01,  1.7578e-01],\n",
       "          [ 1.5820e-01,  6.5625e-01,  2.8516e-01,  ...,  7.5000e-01,\n",
       "            6.0938e-01, -1.9922e+00],\n",
       "          [-1.5723e-01,  5.5908e-02,  2.1680e-01,  ..., -7.6562e-01,\n",
       "           -1.3770e-01, -7.7344e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.8433e-02, -1.1169e-02,  2.4048e-02,  ..., -1.5332e-01,\n",
       "           -2.2852e-01, -1.6211e-01],\n",
       "          [ 3.6719e-01,  2.0410e-01,  8.2422e-01,  ...,  1.5859e+00,\n",
       "            9.3359e-01,  4.0234e-01],\n",
       "          [ 1.6504e-01,  3.6133e-01, -2.6562e-01,  ...,  1.3672e+00,\n",
       "            1.8516e+00,  7.7734e-01],\n",
       "          ...,\n",
       "          [-3.5547e-01, -4.3945e-01,  3.2031e-01,  ...,  1.4609e+00,\n",
       "            2.3906e+00,  8.2520e-02],\n",
       "          [-2.2461e-01, -5.0000e-01, -2.8809e-02,  ...,  2.7930e-01,\n",
       "            1.4141e+00, -6.2109e-01],\n",
       "          [-8.7891e-02, -4.0430e-01,  1.0156e-01,  ...,  3.8086e-01,\n",
       "            2.1406e+00, -9.6094e-01]],\n",
       "\n",
       "         [[-1.1902e-02, -3.2471e-02,  2.3682e-02,  ..., -2.6172e-01,\n",
       "           -4.2578e-01, -3.9648e-01],\n",
       "          [ 6.3281e-01, -1.1328e-01, -3.9258e-01,  ...,  1.3359e+00,\n",
       "            1.3203e+00, -4.5703e-01],\n",
       "          [-1.7871e-01, -1.0986e-03, -2.9297e-01,  ...,  7.8516e-01,\n",
       "            1.1094e+00,  3.4961e-01],\n",
       "          ...,\n",
       "          [-5.5078e-01, -5.5469e-01,  5.3906e-01,  ..., -3.8672e-01,\n",
       "            6.1328e-01, -7.6172e-01],\n",
       "          [-1.6968e-02,  3.5547e-01,  6.5234e-01,  ...,  3.1250e-01,\n",
       "            1.1016e+00,  1.4375e+00],\n",
       "          [-4.9609e-01,  1.0391e+00,  7.3047e-01,  ...,  4.7656e-01,\n",
       "            2.0000e+00,  7.1484e-01]],\n",
       "\n",
       "         [[ 1.7395e-03, -2.7832e-02,  2.7618e-03,  ..., -1.9531e-01,\n",
       "           -4.5703e-01, -1.3086e-01],\n",
       "          [-1.1562e+00,  1.2793e-01, -8.9355e-02,  ...,  9.1797e-01,\n",
       "            1.8047e+00, -1.8828e+00],\n",
       "          [ 2.3926e-01,  2.3633e-01, -4.6387e-02,  ...,  9.2969e-01,\n",
       "            2.2812e+00, -8.9062e-01],\n",
       "          ...,\n",
       "          [-3.7500e-01, -8.7109e-01, -3.4570e-01,  ..., -3.6094e+00,\n",
       "            1.0312e+00, -1.5078e+00],\n",
       "          [-5.6641e-01, -5.1562e-01,  7.7344e-01,  ..., -6.9922e-01,\n",
       "            2.0938e+00, -9.0234e-01],\n",
       "          [-9.2969e-01, -5.6250e-01, -7.3242e-03,  ..., -3.1641e-01,\n",
       "            1.0078e+00,  1.8262e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 2.4261e-03, -2.0752e-03,  2.5482e-03,  ...,  5.4016e-03,\n",
       "           -1.1520e-03, -2.9602e-03],\n",
       "          [ 9.0820e-02, -2.1387e-01,  2.3730e-01,  ..., -3.0859e-01,\n",
       "            2.6367e-01, -7.9102e-02],\n",
       "          [ 1.2891e-01,  1.8066e-01,  3.5742e-01,  ..., -1.1426e-01,\n",
       "           -1.3965e-01,  1.3379e-01],\n",
       "          ...,\n",
       "          [ 1.4062e-01,  6.6895e-02, -3.5938e-01,  ...,  8.7402e-02,\n",
       "           -2.0020e-01,  1.1169e-02],\n",
       "          [-4.3213e-02, -9.8145e-02, -3.6133e-01,  ..., -2.2095e-02,\n",
       "           -8.9355e-02,  4.1016e-01],\n",
       "          [-2.8906e-01, -5.0964e-03, -1.6406e-01,  ..., -1.7676e-01,\n",
       "           -2.0410e-01,  2.5977e-01]],\n",
       "\n",
       "         [[-3.2959e-03, -8.6060e-03,  6.6223e-03,  ...,  7.4387e-04,\n",
       "            6.7444e-03,  5.7068e-03],\n",
       "          [-3.3398e-01,  4.7461e-01, -3.1445e-01,  ..., -1.1279e-01,\n",
       "            2.2363e-01,  5.1172e-01],\n",
       "          [-4.2969e-01,  3.7305e-01, -1.2695e-01,  ...,  2.3438e-01,\n",
       "           -6.9824e-02,  9.9609e-02],\n",
       "          ...,\n",
       "          [-1.3477e-01, -9.1797e-02,  6.7383e-02,  ...,  3.2422e-01,\n",
       "            3.8086e-01, -3.4961e-01],\n",
       "          [-6.8848e-02, -3.4180e-01,  3.3691e-02,  ..., -4.8633e-01,\n",
       "            1.4746e-01, -1.9531e-01],\n",
       "          [ 2.0996e-01, -1.0254e-01, -2.9663e-02,  ..., -3.6133e-01,\n",
       "            1.4746e-01,  3.0469e-01]],\n",
       "\n",
       "         [[ 1.4114e-03,  5.0659e-03,  1.0757e-03,  ..., -1.9836e-03,\n",
       "           -3.2349e-03,  1.7822e-02],\n",
       "          [ 1.3672e-01,  2.9688e-01,  2.7539e-01,  ..., -3.1836e-01,\n",
       "           -9.4238e-02, -8.6426e-02],\n",
       "          [-1.0645e-01, -1.8848e-01,  6.2891e-01,  ...,  7.4609e-01,\n",
       "           -2.3340e-01, -1.1292e-02],\n",
       "          ...,\n",
       "          [-2.5977e-01, -5.7068e-03, -3.0859e-01,  ...,  3.9258e-01,\n",
       "            4.7363e-02, -2.7734e-01],\n",
       "          [-9.4727e-02,  2.7734e-01, -1.2402e-01,  ..., -5.7031e-01,\n",
       "            2.1875e-01, -9.4141e-01],\n",
       "          [ 2.4512e-01, -2.8125e-01, -9.0332e-02,  ..., -6.6406e-02,\n",
       "            2.1777e-01, -7.0312e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 4.0588e-03, -2.1210e-03, -1.0452e-03,  ...,  3.2501e-03,\n",
       "            3.0518e-03, -3.0670e-03],\n",
       "          [-3.9844e-01,  6.7188e-01, -2.3315e-02,  ...,  4.8828e-01,\n",
       "            2.2852e-01,  1.1230e-01],\n",
       "          [-1.4062e-01,  4.1260e-02,  1.2158e-01,  ...,  1.0547e-01,\n",
       "           -9.4727e-02,  1.4453e-01],\n",
       "          ...,\n",
       "          [-5.1562e-01,  2.4023e-01,  5.5859e-01,  ...,  3.7842e-02,\n",
       "            1.0156e-01,  7.1289e-02],\n",
       "          [-4.3359e-01,  2.5195e-01, -8.3008e-02,  ..., -2.4707e-01,\n",
       "            1.5234e-01,  3.8672e-01],\n",
       "          [-4.8242e-01,  1.7578e-01, -1.7773e-01,  ..., -1.5625e-01,\n",
       "            6.2256e-03,  2.3340e-01]],\n",
       "\n",
       "         [[-3.3722e-03,  1.4801e-03, -6.0730e-03,  ..., -2.1606e-02,\n",
       "            7.7438e-04, -6.7520e-04],\n",
       "          [-1.8082e-03,  7.0312e-02,  1.6895e-01,  ...,  2.0996e-01,\n",
       "           -4.6631e-02, -1.4453e-01],\n",
       "          [ 1.1670e-01, -7.7148e-02, -4.8584e-02,  ...,  2.8516e-01,\n",
       "           -1.0791e-01,  1.5625e-01],\n",
       "          ...,\n",
       "          [ 1.5430e-01,  1.9531e-01, -2.0117e-01,  ..., -1.2256e-01,\n",
       "           -2.8711e-01,  1.2012e-01],\n",
       "          [ 2.7930e-01,  1.3574e-01, -3.9551e-02,  ..., -6.7871e-02,\n",
       "           -3.0859e-01,  3.4790e-03],\n",
       "          [ 1.4453e-01,  3.0273e-01,  3.3008e-01,  ...,  2.6562e-01,\n",
       "           -2.1484e-01,  2.0142e-02]],\n",
       "\n",
       "         [[ 7.0496e-03, -1.9897e-02, -5.4932e-03,  ...,  2.5330e-03,\n",
       "           -5.9509e-03, -5.0659e-03],\n",
       "          [-3.3984e-01, -1.3379e-01,  9.0820e-02,  ...,  8.9844e-02,\n",
       "            6.3965e-02, -2.6172e-01],\n",
       "          [-3.4570e-01,  3.9844e-01, -1.1377e-01,  ..., -4.1016e-02,\n",
       "            2.0215e-01, -1.8262e-01],\n",
       "          ...,\n",
       "          [-7.1289e-02,  2.5000e-01,  6.6757e-04,  ...,  3.5352e-01,\n",
       "            1.6968e-02, -9.6680e-02],\n",
       "          [ 1.4746e-01, -3.5156e-02, -2.3071e-02,  ..., -9.4727e-02,\n",
       "           -5.1514e-02,  5.8838e-02],\n",
       "          [ 1.7285e-01, -1.6992e-01,  1.3367e-02,  ...,  1.9727e-01,\n",
       "           -1.1816e-01, -1.5564e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-4.6875e-02, -6.3324e-04, -2.8076e-03,  ...,  1.9043e-02,\n",
       "           -2.1777e-01,  1.0620e-02],\n",
       "          [-7.2266e-02, -2.1484e-02,  2.3730e-01,  ...,  6.7188e-01,\n",
       "            1.2734e+00, -2.5586e-01],\n",
       "          [-3.7500e-01, -5.6641e-01,  4.4141e-01,  ..., -2.2949e-01,\n",
       "            2.2969e+00, -3.3984e-01],\n",
       "          ...,\n",
       "          [ 4.4531e-01,  3.4180e-01, -6.6406e-02,  ..., -2.4316e-01,\n",
       "            1.5547e+00,  1.0625e+00],\n",
       "          [ 2.1094e-01,  1.0469e+00, -6.3281e-01,  ..., -1.7812e+00,\n",
       "            8.4375e-01,  2.1387e-01],\n",
       "          [-2.8516e-01,  2.1875e-01, -7.2656e-01,  ..., -3.8672e-01,\n",
       "            6.1328e-01,  9.4141e-01]],\n",
       "\n",
       "         [[-8.1177e-03,  3.0273e-02, -2.4658e-02,  ...,  4.2114e-03,\n",
       "           -1.3916e-02, -1.4893e-02],\n",
       "          [ 6.6406e-02, -5.5469e-01,  1.6602e-02,  ...,  1.3281e+00,\n",
       "           -1.1328e+00,  7.4609e-01],\n",
       "          [ 1.0693e-01, -1.0059e-01, -1.4062e-01,  ..., -5.0781e-01,\n",
       "           -8.0859e-01,  1.2734e+00],\n",
       "          ...,\n",
       "          [-2.4707e-01,  5.3125e-01, -7.6660e-02,  ...,  4.4727e-01,\n",
       "            1.2500e+00,  1.9297e+00],\n",
       "          [ 2.0020e-01,  3.0664e-01,  2.5977e-01,  ...,  6.5234e-01,\n",
       "            7.7344e-01,  1.0547e+00],\n",
       "          [ 1.6406e-01,  2.0703e-01,  9.9121e-02,  ...,  5.9375e-01,\n",
       "            9.2969e-01,  1.0781e+00]],\n",
       "\n",
       "         [[ 5.3406e-03,  1.8311e-02, -5.4932e-03,  ..., -1.8750e-01,\n",
       "            3.6914e-01, -8.5156e-01],\n",
       "          [-6.2500e-01, -1.3281e-01,  5.3906e-01,  ...,  5.8594e-01,\n",
       "           -4.8828e-01,  2.1406e+00],\n",
       "          [ 6.0547e-01, -1.5625e+00, -4.8828e-01,  ...,  1.6094e+00,\n",
       "           -4.1406e-01,  2.8906e+00],\n",
       "          ...,\n",
       "          [-9.5703e-01,  1.2266e+00, -1.4453e-01,  ...,  1.7422e+00,\n",
       "           -1.9766e+00,  3.1875e+00],\n",
       "          [-1.0156e+00,  6.2500e-01,  4.1797e-01,  ...,  4.2773e-01,\n",
       "           -8.4375e-01,  3.2188e+00],\n",
       "          [ 1.2031e+00, -9.8633e-02,  8.5547e-01,  ...,  9.6094e-01,\n",
       "           -9.8828e-01,  2.7500e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-2.1973e-02, -1.4709e-02,  1.4404e-02,  ..., -4.3359e-01,\n",
       "           -4.1016e-01,  7.4219e-02],\n",
       "          [-2.7930e-01,  2.2266e-01,  9.9121e-02,  ..., -7.8906e-01,\n",
       "            1.6953e+00,  2.0801e-01],\n",
       "          [-3.4668e-02,  9.2773e-02, -4.1992e-01,  ...,  1.2812e+00,\n",
       "            2.3438e+00,  4.5898e-02],\n",
       "          ...,\n",
       "          [ 2.5000e-01, -4.8584e-02, -4.3555e-01,  ...,  1.8906e+00,\n",
       "            1.2266e+00, -3.8086e-01],\n",
       "          [ 2.3242e-01, -5.4297e-01, -8.0078e-01,  ...,  1.0234e+00,\n",
       "            1.4922e+00, -1.2812e+00],\n",
       "          [-2.4707e-01, -1.0889e-01, -1.1426e-01,  ...,  2.9688e+00,\n",
       "            1.1172e+00, -1.2734e+00]],\n",
       "\n",
       "         [[-3.7842e-02,  2.0020e-02,  1.9150e-03,  ..., -2.1680e-01,\n",
       "           -1.0840e-01, -1.2012e-01],\n",
       "          [ 3.1250e-02, -8.0469e-01,  3.7500e-01,  ..., -6.4453e-02,\n",
       "           -1.8750e-01,  9.2578e-01],\n",
       "          [ 5.7031e-01,  3.3594e-01,  1.7188e-01,  ...,  2.1094e+00,\n",
       "           -3.4668e-02,  8.2422e-01],\n",
       "          ...,\n",
       "          [-3.8477e-01, -2.5586e-01,  3.5742e-01,  ...,  1.7656e+00,\n",
       "            1.2812e+00, -4.0625e-01],\n",
       "          [-1.7285e-01, -5.8594e-03, -5.1172e-01,  ...,  1.5547e+00,\n",
       "            1.5859e+00, -7.3438e-01],\n",
       "          [ 5.8594e-02,  2.4902e-02,  2.6758e-01,  ...,  1.1016e+00,\n",
       "            9.8438e-01, -1.2734e+00]],\n",
       "\n",
       "         [[ 3.2227e-02, -4.3030e-03,  3.6812e-04,  ...,  8.8672e-01,\n",
       "           -1.1768e-01,  7.4609e-01],\n",
       "          [ 3.7109e-01,  1.3379e-01, -1.3672e-01,  ..., -1.1172e+00,\n",
       "            2.2031e+00, -2.6250e+00],\n",
       "          [ 3.8672e-01,  5.6641e-02,  5.2979e-02,  ..., -1.9297e+00,\n",
       "            1.8125e+00, -2.6406e+00],\n",
       "          ...,\n",
       "          [-5.8594e-01, -7.0312e-02, -3.9258e-01,  ..., -2.0469e+00,\n",
       "            5.3516e-01, -2.6719e+00],\n",
       "          [-5.0391e-01,  1.6094e+00, -4.5312e-01,  ..., -3.1406e+00,\n",
       "           -1.2266e+00, -2.5625e+00],\n",
       "          [ 2.2461e-01,  6.3672e-01, -3.6133e-01,  ..., -2.4375e+00,\n",
       "            9.0234e-01, -3.2812e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-5.0964e-03, -2.4658e-02, -3.8910e-03,  ...,  9.3994e-03,\n",
       "            1.3733e-02, -1.1841e-02],\n",
       "          [ 2.1191e-01,  5.8203e-01,  2.9688e-01,  ...,  2.2070e-01,\n",
       "           -2.3560e-02,  5.3125e-01],\n",
       "          [-1.7871e-01, -1.0864e-02,  1.9043e-01,  ..., -2.2583e-02,\n",
       "           -1.0059e-01, -1.8262e-01],\n",
       "          ...,\n",
       "          [-3.2031e-01,  3.2812e-01, -5.5859e-01,  ..., -1.9531e-01,\n",
       "           -3.6523e-01, -3.1641e-01],\n",
       "          [-6.9141e-01,  1.4258e-01, -5.6641e-01,  ...,  8.9844e-02,\n",
       "            2.6172e-01, -7.6562e-01],\n",
       "          [-8.6328e-01,  1.3770e-01, -5.1562e-01,  ...,  2.2266e-01,\n",
       "            5.3516e-01, -1.0625e+00]],\n",
       "\n",
       "         [[-2.1362e-03,  1.2695e-02,  1.6327e-03,  ..., -1.0620e-02,\n",
       "           -8.4839e-03, -3.6621e-03],\n",
       "          [ 4.8047e-01, -2.4707e-01,  2.6172e-01,  ...,  3.3789e-01,\n",
       "            5.8203e-01, -6.7383e-02],\n",
       "          [ 1.7676e-01,  1.0938e-01,  1.7383e-01,  ..., -4.6143e-02,\n",
       "           -5.2979e-02, -2.4536e-02],\n",
       "          ...,\n",
       "          [-6.5234e-01,  1.2085e-02,  1.2500e-01,  ...,  2.7344e-01,\n",
       "           -1.9824e-01,  2.5586e-01],\n",
       "          [ 1.7578e-01, -2.2461e-01, -1.4355e-01,  ...,  2.1973e-01,\n",
       "           -3.0859e-01,  7.1289e-02],\n",
       "          [ 4.6094e-01, -5.8203e-01, -4.4531e-01,  ..., -4.3359e-01,\n",
       "            2.0410e-01,  1.0400e-01]],\n",
       "\n",
       "         [[ 2.8992e-03, -1.3809e-03, -1.1780e-02,  ...,  1.3809e-03,\n",
       "            1.9897e-02, -5.8289e-03],\n",
       "          [ 4.4727e-01, -2.9297e-01, -7.0312e-01,  ..., -7.8613e-02,\n",
       "            2.7148e-01,  1.0303e-01],\n",
       "          [-4.7070e-01,  2.3828e-01, -3.9258e-01,  ..., -3.8672e-01,\n",
       "           -2.7466e-02,  5.0049e-03],\n",
       "          ...,\n",
       "          [ 7.2266e-02,  5.8203e-01,  1.7822e-02,  ..., -8.5938e-01,\n",
       "           -4.1211e-01, -3.3789e-01],\n",
       "          [-1.4941e-01,  2.0312e-01, -3.6328e-01,  ..., -4.4922e-01,\n",
       "            7.3242e-02, -5.3125e-01],\n",
       "          [ 2.0801e-01,  4.4922e-01,  2.0215e-01,  ..., -5.8594e-01,\n",
       "            2.8516e-01, -1.7871e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-9.3994e-03, -1.1673e-03,  7.5378e-03,  ..., -2.5330e-03,\n",
       "           -7.4463e-03, -2.7847e-04],\n",
       "          [ 1.0498e-01, -5.0049e-02,  4.4141e-01,  ..., -1.9824e-01,\n",
       "           -1.0693e-01,  4.0039e-01],\n",
       "          [-3.5889e-02, -2.5195e-01,  3.1250e-01,  ..., -8.6426e-02,\n",
       "            5.7812e-01,  3.2617e-01],\n",
       "          ...,\n",
       "          [ 1.9336e-01, -4.1797e-01,  1.9824e-01,  ...,  2.5781e-01,\n",
       "            4.1992e-01, -6.2891e-01],\n",
       "          [ 5.8594e-01,  1.8164e-01, -2.0020e-01,  ..., -1.5723e-01,\n",
       "            1.6309e-01, -1.0742e-01],\n",
       "          [ 5.4297e-01, -1.7871e-01, -1.4355e-01,  ...,  2.5586e-01,\n",
       "            3.1433e-03, -2.2656e-01]],\n",
       "\n",
       "         [[-2.0447e-03,  3.0975e-03, -2.0905e-03,  ..., -3.9978e-03,\n",
       "           -5.3787e-04,  6.1340e-03],\n",
       "          [-3.5742e-01,  4.9072e-02,  2.6172e-01,  ...,  1.3867e-01,\n",
       "            1.7188e-01, -1.6992e-01],\n",
       "          [-4.1992e-01,  1.4258e-01,  6.0156e-01,  ...,  1.9727e-01,\n",
       "            1.1865e-01,  1.3574e-01],\n",
       "          ...,\n",
       "          [ 1.8750e-01,  3.7354e-02,  4.7266e-01,  ...,  3.5547e-01,\n",
       "           -4.3164e-01, -6.2256e-02],\n",
       "          [ 1.0352e-01, -3.6719e-01,  2.2363e-01,  ...,  2.4902e-01,\n",
       "           -1.6992e-01, -2.6758e-01],\n",
       "          [ 1.0376e-02, -4.0625e-01,  3.8281e-01,  ...,  2.4512e-01,\n",
       "            2.8442e-02,  2.3145e-01]],\n",
       "\n",
       "         [[-4.5166e-03, -8.8501e-03,  1.1780e-02,  ...,  4.6997e-03,\n",
       "           -1.4954e-03,  1.5030e-03],\n",
       "          [ 2.5000e-01,  1.2061e-01, -8.5449e-02,  ..., -1.2598e-01,\n",
       "            2.7734e-01, -3.3203e-01],\n",
       "          [ 1.1865e-01,  1.3086e-01, -1.2256e-01,  ...,  1.1035e-01,\n",
       "           -1.2451e-02, -9.4727e-02],\n",
       "          ...,\n",
       "          [-1.5332e-01,  1.6211e-01,  1.7090e-01,  ...,  1.8262e-01,\n",
       "           -2.0898e-01,  8.1055e-02],\n",
       "          [ 9.4727e-02,  1.0059e-01,  2.2461e-01,  ...,  4.7070e-01,\n",
       "           -3.1641e-01,  1.2158e-01],\n",
       "          [-1.0059e-01,  2.5781e-01, -3.3398e-01,  ...,  2.5781e-01,\n",
       "           -2.6172e-01,  7.0801e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-9.2773e-03, -1.3809e-03,  1.3580e-03,  ..., -1.2256e-01,\n",
       "           -2.7344e-01,  5.3125e-01],\n",
       "          [-2.4023e-01, -7.0312e-01,  2.4219e-01,  ...,  8.1641e-01,\n",
       "            8.5938e-01, -1.6641e+00],\n",
       "          [-6.5625e-01, -7.8516e-01,  1.0156e+00,  ...,  2.0781e+00,\n",
       "           -1.9434e-01, -1.9297e+00],\n",
       "          ...,\n",
       "          [-9.3750e-02,  5.1172e-01, -1.7676e-01,  ...,  5.4688e-01,\n",
       "           -9.6094e-01, -2.3125e+00],\n",
       "          [-1.2109e-01,  2.9297e-03, -1.6602e-02,  ..., -4.6094e-01,\n",
       "            1.0859e+00, -2.1406e+00],\n",
       "          [ 5.3125e-01, -5.2344e-01,  6.4844e-01,  ...,  9.4141e-01,\n",
       "           -4.1211e-01, -2.0781e+00]],\n",
       "\n",
       "         [[-6.3171e-03,  1.9409e-02, -9.7656e-03,  ...,  1.2656e+00,\n",
       "           -2.7930e-01, -3.4961e-01],\n",
       "          [ 2.4023e-01,  6.4062e-01,  4.7266e-01,  ..., -2.2500e+00,\n",
       "            1.1484e+00,  6.5234e-01],\n",
       "          [ 2.4512e-01,  1.4355e-01,  5.4932e-02,  ..., -2.0469e+00,\n",
       "            1.2188e+00,  8.5938e-01],\n",
       "          ...,\n",
       "          [-1.6797e-01, -4.1211e-01, -5.9375e-01,  ..., -3.6250e+00,\n",
       "            6.7969e-01, -7.2266e-01],\n",
       "          [-9.1406e-01, -7.8906e-01, -7.8516e-01,  ..., -4.7188e+00,\n",
       "           -2.0703e-01, -2.6094e+00],\n",
       "          [-4.6680e-01,  6.3477e-02, -4.4727e-01,  ..., -4.2188e+00,\n",
       "            9.2773e-02, -2.5312e+00]],\n",
       "\n",
       "         [[ 3.4912e-02, -4.1199e-03,  4.1992e-02,  ...,  6.0059e-02,\n",
       "           -5.9204e-03, -5.5908e-02],\n",
       "          [ 3.1055e-01,  2.1191e-01, -4.9609e-01,  ..., -3.1055e-01,\n",
       "            1.2188e+00,  1.0625e+00],\n",
       "          [ 3.4961e-01,  2.6367e-01, -7.4707e-02,  ...,  8.4961e-02,\n",
       "            1.1328e+00,  9.1016e-01],\n",
       "          ...,\n",
       "          [ 3.2227e-01,  1.8262e-01,  2.5000e-01,  ..., -5.1953e-01,\n",
       "            1.4766e+00,  5.4297e-01],\n",
       "          [-1.1914e-01,  7.0703e-01,  4.1992e-01,  ..., -5.9766e-01,\n",
       "            5.7031e-01,  6.2500e-01],\n",
       "          [-5.1562e-01, -9.7656e-04,  8.0078e-02,  ...,  1.0234e+00,\n",
       "            2.4062e+00, -3.9978e-03]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.5747e-02,  4.2969e-02,  1.1063e-03,  ...,  1.5723e-01,\n",
       "            1.4160e-01, -1.0840e-01],\n",
       "          [ 3.2031e-01,  2.4023e-01,  1.7656e+00,  ...,  1.6484e+00,\n",
       "           -1.2109e+00,  1.3047e+00],\n",
       "          [ 3.6133e-01,  5.8594e-03,  5.9375e-01,  ...,  2.0625e+00,\n",
       "           -7.3828e-01, -7.1484e-01],\n",
       "          ...,\n",
       "          [-1.6602e-02,  3.6133e-01, -1.4648e-02,  ..., -1.5430e-01,\n",
       "           -3.9062e-01, -8.5547e-01],\n",
       "          [-2.5391e-02, -1.0078e+00, -8.6328e-01,  ...,  4.3701e-02,\n",
       "            3.1055e-01, -1.3047e+00],\n",
       "          [ 1.5137e-01,  8.7891e-03, -1.6602e-01,  ...,  3.2422e-01,\n",
       "           -2.2949e-01, -4.6875e-01]],\n",
       "\n",
       "         [[-3.6865e-02, -2.2949e-02,  1.0010e-02,  ..., -2.1289e-01,\n",
       "            3.1055e-01,  5.8203e-01],\n",
       "          [-1.7090e-01, -9.3262e-02,  1.9434e-01,  ..., -8.0566e-02,\n",
       "           -1.0400e-01, -2.0156e+00],\n",
       "          [ 3.3203e-02,  3.2227e-01,  2.4219e-01,  ...,  1.6094e+00,\n",
       "           -8.9453e-01, -4.4062e+00],\n",
       "          ...,\n",
       "          [-1.9922e-01,  7.1289e-02, -2.9297e-01,  ...,  1.9609e+00,\n",
       "           -1.1016e+00, -2.3438e+00],\n",
       "          [-7.9102e-02,  1.3965e-01, -3.7842e-02,  ...,  1.8594e+00,\n",
       "           -7.2656e-01, -2.6406e+00],\n",
       "          [ 4.8828e-04,  4.0820e-01,  3.4570e-01,  ...,  4.7852e-01,\n",
       "            1.9609e+00, -6.2891e-01]],\n",
       "\n",
       "         [[ 1.6846e-02,  1.1902e-02,  3.7842e-02,  ..., -2.3633e-01,\n",
       "           -1.2207e-01,  8.3496e-02],\n",
       "          [ 3.4668e-02,  1.5430e-01, -5.2734e-02,  ...,  7.1484e-01,\n",
       "           -1.1016e+00,  8.8379e-02],\n",
       "          [-3.1836e-01, -4.0234e-01, -5.6641e-01,  ...,  1.7734e+00,\n",
       "            1.7285e-01, -1.8262e-01],\n",
       "          ...,\n",
       "          [-7.8613e-02, -2.8076e-02,  3.0859e-01,  ...,  1.3594e+00,\n",
       "            2.0312e+00, -1.6094e+00],\n",
       "          [ 1.2305e-01, -5.8594e-01, -4.1016e-01,  ...,  2.9102e-01,\n",
       "            1.8164e-01, -6.1328e-01],\n",
       "          [-5.7031e-01, -1.7969e-01, -7.2266e-01,  ..., -2.1484e-01,\n",
       "            9.8047e-01, -8.6328e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-9.2163e-03, -2.2583e-02, -3.1891e-03,  ..., -1.5717e-03,\n",
       "           -4.9133e-03, -8.3618e-03],\n",
       "          [-7.4219e-01,  6.7969e-01, -1.4062e-01,  ..., -7.8516e-01,\n",
       "            6.7188e-01,  3.5938e-01],\n",
       "          [ 1.7285e-01,  3.3594e-01,  4.0820e-01,  ...,  1.5039e-01,\n",
       "            1.4551e-01,  2.6367e-01],\n",
       "          ...,\n",
       "          [-9.1309e-02,  4.3555e-01,  6.9824e-02,  ...,  1.0400e-01,\n",
       "            3.3188e-04, -4.1602e-01],\n",
       "          [-2.9883e-01, -2.4121e-01, -1.7773e-01,  ...,  1.0107e-01,\n",
       "            5.1172e-01,  2.9688e-01],\n",
       "          [ 2.5391e-01, -3.8086e-01,  2.8906e-01,  ..., -4.5898e-01,\n",
       "           -1.1426e-01,  3.8086e-01]],\n",
       "\n",
       "         [[ 1.6113e-02, -1.3351e-03,  2.8839e-03,  ...,  9.3384e-03,\n",
       "            2.6245e-03, -2.3651e-03],\n",
       "          [-5.1172e-01,  3.8330e-02,  1.4404e-02,  ..., -2.0801e-01,\n",
       "            1.4746e-01, -2.6172e-01],\n",
       "          [-1.4844e-01,  4.5898e-01, -3.6719e-01,  ..., -3.8867e-01,\n",
       "           -1.0645e-01, -2.3560e-02],\n",
       "          ...,\n",
       "          [ 1.9434e-01,  3.4570e-01,  7.9346e-03,  ..., -1.6113e-01,\n",
       "            1.5430e-01,  6.0938e-01],\n",
       "          [ 1.7480e-01,  8.2422e-01, -1.2012e-01,  ..., -4.1602e-01,\n",
       "            2.0801e-01,  2.3438e-01],\n",
       "          [-3.0078e-01,  1.0859e+00, -3.4570e-01,  ..., -6.0547e-01,\n",
       "           -4.9072e-02, -1.3184e-01]],\n",
       "\n",
       "         [[-5.6763e-03, -2.0752e-02,  6.7139e-03,  ...,  9.6191e-02,\n",
       "           -4.3106e-04, -6.2256e-03],\n",
       "          [ 3.7109e-01, -3.8086e-01,  2.1875e-01,  ..., -1.2188e+00,\n",
       "            2.3828e-01,  1.7578e-02],\n",
       "          [-2.7539e-01,  2.2754e-01, -4.9609e-01,  ..., -8.8281e-01,\n",
       "            5.8594e-01, -6.7578e-01],\n",
       "          ...,\n",
       "          [-1.8652e-01, -9.0332e-02, -1.5332e-01,  ..., -6.9922e-01,\n",
       "            6.7871e-02,  2.1973e-03],\n",
       "          [ 3.6621e-02, -5.9570e-02, -2.2949e-01,  ..., -1.9141e-01,\n",
       "           -3.5938e-01, -4.7461e-01],\n",
       "          [-5.0781e-01, -5.3906e-01, -3.5156e-01,  ..., -5.0391e-01,\n",
       "           -7.0801e-02, -2.5269e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-5.6458e-03, -1.0620e-02,  2.9449e-03,  ..., -1.9989e-03,\n",
       "            3.6163e-03,  2.2888e-03],\n",
       "          [ 6.6895e-02, -3.9258e-01,  2.8516e-01,  ...,  2.8125e-01,\n",
       "           -6.7969e-01,  5.5420e-02],\n",
       "          [ 3.4375e-01, -1.7090e-01,  4.6875e-01,  ...,  4.5312e-01,\n",
       "           -6.2109e-01, -7.5781e-01],\n",
       "          ...,\n",
       "          [ 2.8125e-01,  6.0547e-01,  2.6367e-01,  ...,  1.7285e-01,\n",
       "           -5.5469e-01, -2.1191e-01],\n",
       "          [ 1.4453e-01,  3.5352e-01,  2.0142e-02,  ...,  4.5898e-01,\n",
       "           -4.8633e-01,  3.6133e-02],\n",
       "          [ 4.4336e-01,  8.0566e-02,  7.0312e-02,  ...,  2.2949e-01,\n",
       "           -3.2031e-01,  2.5781e-01]],\n",
       "\n",
       "         [[-1.1658e-02,  5.7678e-03, -1.0300e-04,  ..., -9.8877e-03,\n",
       "           -1.0742e-02,  1.4191e-03],\n",
       "          [ 1.7578e-01, -4.3701e-02,  1.4844e-01,  ...,  2.7008e-03,\n",
       "            4.3164e-01,  1.3086e-01],\n",
       "          [ 7.4219e-02, -2.1875e-01,  3.1494e-02,  ...,  4.2383e-01,\n",
       "            2.5482e-03,  1.1816e-01],\n",
       "          ...,\n",
       "          [ 1.2354e-01,  7.1777e-02,  1.0693e-01,  ...,  7.2656e-01,\n",
       "            2.4121e-01, -2.4219e-01],\n",
       "          [ 8.4839e-03,  2.3438e-01,  2.7344e-01,  ...,  6.9531e-01,\n",
       "            1.1475e-01, -2.5000e-01],\n",
       "          [ 1.6602e-01, -1.0986e-01,  3.2812e-01,  ...,  1.3086e-01,\n",
       "            9.6191e-02, -3.6523e-01]],\n",
       "\n",
       "         [[ 1.9531e-03, -3.3417e-03, -5.4016e-03,  ...,  3.5553e-03,\n",
       "           -7.2327e-03, -6.0654e-04],\n",
       "          [ 2.7734e-01, -9.6436e-03, -3.7109e-01,  ...,  4.4141e-01,\n",
       "           -4.4922e-01, -3.1641e-01],\n",
       "          [-2.3340e-01,  2.9297e-01, -4.9072e-02,  ..., -3.5938e-01,\n",
       "           -2.6758e-01, -4.7461e-01],\n",
       "          ...,\n",
       "          [ 6.5234e-01, -7.0312e-02, -1.6403e-03,  ..., -7.3730e-02,\n",
       "            2.7930e-01, -2.3828e-01],\n",
       "          [ 5.4297e-01, -7.2266e-02,  7.4707e-02,  ..., -2.4805e-01,\n",
       "           -1.6797e-01, -9.9121e-02],\n",
       "          [ 4.3359e-01,  2.2754e-01,  6.4453e-01,  ...,  4.1260e-02,\n",
       "           -6.9922e-01,  7.6660e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-3.1494e-02,  7.8735e-03,  3.3203e-02,  ...,  8.3496e-02,\n",
       "            2.0996e-01,  4.3457e-02],\n",
       "          [ 4.0234e-01, -2.5781e-01, -5.4688e-02,  ..., -3.3984e-01,\n",
       "            1.9922e+00, -9.8438e-01],\n",
       "          [ 1.8848e-01, -2.8711e-01, -3.1250e-01,  ...,  5.9766e-01,\n",
       "            9.9609e-01, -2.8750e+00],\n",
       "          ...,\n",
       "          [-2.7539e-01, -7.9102e-02,  4.2969e-01,  ..., -3.4766e-01,\n",
       "            6.2891e-01,  1.8457e-01],\n",
       "          [ 1.8799e-02,  2.0312e-01,  4.1260e-02,  ...,  1.2451e-01,\n",
       "            8.8281e-01,  3.7891e-01],\n",
       "          [-9.9609e-02,  2.3047e-01,  1.4355e-01,  ...,  5.1953e-01,\n",
       "           -2.8320e-01,  2.4609e-01]],\n",
       "\n",
       "         [[-1.3428e-02,  5.6458e-03,  5.6152e-03,  ..., -1.6309e-01,\n",
       "           -1.2360e-03,  1.5381e-02],\n",
       "          [-1.3477e-01,  2.2168e-01,  7.1484e-01,  ..., -1.0303e-01,\n",
       "           -1.1562e+00,  5.1172e-01],\n",
       "          [ 6.2500e-01, -5.3516e-01,  6.5625e-01,  ..., -1.0859e+00,\n",
       "           -9.4141e-01, -3.8867e-01],\n",
       "          ...,\n",
       "          [-8.2812e-01, -6.2012e-02, -1.1523e-01,  ...,  9.3750e-01,\n",
       "           -1.1182e-01, -1.1641e+00],\n",
       "          [ 1.0547e-01, -1.9336e-01, -6.7578e-01,  ...,  8.0859e-01,\n",
       "           -1.4062e+00, -1.4609e+00],\n",
       "          [ 9.2578e-01, -2.1484e-01,  2.5195e-01,  ...,  4.2383e-01,\n",
       "           -6.8359e-01, -1.3516e+00]],\n",
       "\n",
       "         [[ 4.3213e-02, -3.7109e-02, -1.7578e-02,  ..., -1.0254e-01,\n",
       "           -7.1289e-02, -7.9590e-02],\n",
       "          [-2.1973e-03,  1.9629e-01, -8.7402e-02,  ...,  2.3750e+00,\n",
       "            2.3535e-01, -1.9219e+00],\n",
       "          [-5.6250e-01, -8.1543e-02,  7.8516e-01,  ...,  1.1875e+00,\n",
       "           -4.1016e-01,  9.6680e-02],\n",
       "          ...,\n",
       "          [ 5.5859e-01, -2.4902e-01, -2.1875e-01,  ..., -8.9453e-01,\n",
       "            9.1406e-01, -1.4453e+00],\n",
       "          [ 7.7734e-01, -1.3828e+00, -7.4609e-01,  ..., -5.8203e-01,\n",
       "            5.3516e-01, -2.9883e-01],\n",
       "          [-1.1523e-01, -1.7578e-02,  1.1719e-02,  ..., -3.4180e-01,\n",
       "            1.6094e+00, -5.8594e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-9.2773e-03,  2.5024e-02, -3.7689e-03,  ..., -1.7188e-01,\n",
       "           -1.2354e-01, -8.4229e-03],\n",
       "          [-2.1289e-01, -9.2773e-02, -1.5527e-01,  ...,  6.6406e-01,\n",
       "            1.8203e+00, -2.1094e-01],\n",
       "          [-1.6992e-01,  1.5039e-01,  4.0820e-01,  ..., -9.0942e-03,\n",
       "            1.5078e+00,  1.2598e-01],\n",
       "          ...,\n",
       "          [ 1.6113e-01, -4.5898e-02,  1.2031e+00,  ..., -9.4531e-01,\n",
       "            1.4453e+00, -6.7188e-01],\n",
       "          [-1.7773e-01,  1.2354e-01,  2.5781e-01,  ..., -4.7656e-01,\n",
       "            1.2656e+00,  9.7656e-02],\n",
       "          [ 4.5508e-01, -5.8203e-01, -2.8711e-01,  ..., -1.8047e+00,\n",
       "            4.0430e-01,  1.2109e-01]],\n",
       "\n",
       "         [[-4.8218e-03,  2.9755e-03,  5.4443e-02,  ..., -7.0312e-01,\n",
       "           -8.6060e-03,  4.9805e-02],\n",
       "          [-2.9785e-02, -9.7656e-03, -3.8867e-01,  ...,  3.3594e+00,\n",
       "           -4.0234e-01,  1.5332e-01],\n",
       "          [ 4.5703e-01,  5.4688e-01, -4.2188e-01,  ...,  3.3750e+00,\n",
       "            2.3730e-01, -1.7383e-01],\n",
       "          ...,\n",
       "          [ 4.1406e-01, -3.3789e-01, -3.7109e-01,  ...,  4.1562e+00,\n",
       "            1.0596e-01, -1.7031e+00],\n",
       "          [ 8.4766e-01, -4.2383e-01, -5.2344e-01,  ...,  4.9062e+00,\n",
       "           -7.1094e-01, -1.9844e+00],\n",
       "          [ 6.4062e-01, -5.7031e-01, -1.4062e-01,  ...,  3.3438e+00,\n",
       "           -1.8203e+00, -1.4160e-01]],\n",
       "\n",
       "         [[-7.3547e-03, -3.0273e-02,  1.8555e-02,  ..., -1.6562e+00,\n",
       "            9.0820e-02,  3.6133e-02],\n",
       "          [ 2.2168e-01,  9.2773e-02,  4.3359e-01,  ...,  3.8750e+00,\n",
       "            1.7266e+00,  8.3203e-01],\n",
       "          [ 2.4414e-02,  1.5625e-01,  2.7734e-01,  ...,  4.9062e+00,\n",
       "            2.3125e+00,  1.5547e+00],\n",
       "          ...,\n",
       "          [ 2.0996e-01, -3.6523e-01, -9.1406e-01,  ...,  5.2812e+00,\n",
       "            5.8203e-01,  5.1953e-01],\n",
       "          [-5.3125e-01, -2.2266e-01, -7.0312e-01,  ...,  5.1562e+00,\n",
       "            7.6562e-01, -3.5469e+00],\n",
       "          [ 2.2266e-01, -1.5625e-01, -1.5723e-01,  ...,  4.1875e+00,\n",
       "            3.1055e-01, -1.6562e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-5.7861e-02,  3.9978e-03,  3.6774e-03,  ...,  7.1106e-03,\n",
       "           -1.4877e-03,  3.5400e-02],\n",
       "          [ 3.7891e-01, -2.3438e-01, -1.9238e-01,  ...,  3.0078e-01,\n",
       "           -7.9590e-02,  1.1426e-01],\n",
       "          [ 7.6660e-02,  1.9922e-01,  6.1328e-01,  ..., -3.8281e-01,\n",
       "            9.7168e-02,  6.3281e-01],\n",
       "          ...,\n",
       "          [-2.6367e-01, -8.0078e-02,  2.1851e-02,  ..., -1.6113e-01,\n",
       "           -1.6113e-01, -2.2754e-01],\n",
       "          [-5.2344e-01, -1.8164e-01, -4.7070e-01,  ...,  1.0840e-01,\n",
       "           -3.5156e-02, -1.7456e-02],\n",
       "          [-9.4531e-01, -2.9883e-01, -1.2266e+00,  ..., -7.0703e-01,\n",
       "            2.4219e-01,  1.4062e-01]],\n",
       "\n",
       "         [[-5.0354e-03, -1.3489e-02, -7.5073e-03,  ..., -2.2827e-02,\n",
       "            1.2207e-04, -6.0120e-03],\n",
       "          [ 2.0605e-01, -1.9727e-01, -2.9492e-01,  ...,  1.3672e-01,\n",
       "           -7.6953e-01,  3.7956e-04],\n",
       "          [ 3.7500e-01,  7.9590e-02, -5.5908e-02,  ..., -9.0820e-02,\n",
       "           -5.5859e-01, -1.0352e-01],\n",
       "          ...,\n",
       "          [ 5.4688e-01, -1.3203e+00, -4.8438e-01,  ...,  1.4062e-01,\n",
       "            3.1250e-01, -3.0365e-03],\n",
       "          [ 3.0664e-01, -3.3594e-01, -5.8984e-01,  ...,  1.4062e-01,\n",
       "            5.8594e-01,  1.2354e-01],\n",
       "          [-1.5137e-01, -2.3926e-01, -7.9688e-01,  ..., -4.4189e-02,\n",
       "            3.8281e-01, -2.2363e-01]],\n",
       "\n",
       "         [[ 6.2561e-03,  7.7515e-03,  4.7302e-03,  ...,  1.3428e-02,\n",
       "           -6.7139e-03, -2.3193e-03],\n",
       "          [ 3.3203e-02, -4.7070e-01,  5.3711e-02,  ..., -2.1289e-01,\n",
       "           -3.5645e-02,  1.6211e-01],\n",
       "          [-1.6113e-01, -2.6562e-01,  3.0469e-01,  ..., -1.1108e-02,\n",
       "            9.2773e-02,  2.5195e-01],\n",
       "          ...,\n",
       "          [ 5.1953e-01,  8.6328e-01, -8.9355e-02,  ..., -7.4609e-01,\n",
       "            4.9805e-01,  5.2979e-02],\n",
       "          [ 7.3438e-01,  6.5625e-01, -1.4258e-01,  ..., -8.0469e-01,\n",
       "            3.9453e-01,  2.8906e-01],\n",
       "          [ 5.4297e-01,  2.2656e-01, -3.1055e-01,  ..., -1.0469e+00,\n",
       "            3.3203e-01,  7.4219e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.5635e-03,  1.6098e-03, -5.2261e-04,  ...,  4.8523e-03,\n",
       "            9.6436e-03, -3.7956e-04],\n",
       "          [ 1.6699e-01,  2.3926e-01, -9.7656e-02,  ...,  1.1084e-01,\n",
       "            4.0039e-01,  4.3750e-01],\n",
       "          [ 1.5137e-01,  9.3262e-02,  7.1289e-02,  ...,  1.5747e-02,\n",
       "            7.1777e-02,  4.3701e-02],\n",
       "          ...,\n",
       "          [-3.2617e-01, -1.1426e-01, -9.1797e-02,  ...,  4.9023e-01,\n",
       "            8.0859e-01, -3.7500e-01],\n",
       "          [-3.6719e-01, -1.5625e-01, -2.6953e-01,  ...,  1.1230e-01,\n",
       "            8.6719e-01, -7.4707e-02],\n",
       "          [-4.0430e-01,  1.8262e-01, -2.0215e-01,  ..., -1.0400e-01,\n",
       "           -1.3672e-01,  2.6367e-02]],\n",
       "\n",
       "         [[-4.7607e-03, -2.7771e-03,  7.5684e-03,  ...,  9.6512e-04,\n",
       "            3.9978e-03, -8.8501e-03],\n",
       "          [ 2.9688e-01, -2.2070e-01,  2.1777e-01,  ...,  8.1250e-01,\n",
       "            1.7871e-01,  2.0312e-01],\n",
       "          [-4.2578e-01, -3.8281e-01, -5.5859e-01,  ..., -4.0234e-01,\n",
       "            9.6680e-02, -3.9453e-01],\n",
       "          ...,\n",
       "          [ 1.8262e-01,  3.6719e-01, -7.1289e-02,  ...,  1.6992e-01,\n",
       "           -5.7031e-01,  1.9043e-01],\n",
       "          [ 4.9744e-03, -2.8809e-02, -7.6562e-01,  ..., -1.9629e-01,\n",
       "           -3.4961e-01,  1.9531e-01],\n",
       "          [ 3.9844e-01, -3.0859e-01, -7.4707e-02,  ..., -5.1758e-02,\n",
       "           -3.4570e-01, -4.3945e-01]],\n",
       "\n",
       "         [[-4.6387e-03, -1.4114e-03, -7.2327e-03,  ..., -8.8501e-03,\n",
       "            1.7929e-03, -3.5667e-04],\n",
       "          [-6.4941e-02,  5.8105e-02,  2.5391e-01,  ..., -1.4648e-01,\n",
       "           -1.5527e-01,  9.0332e-03],\n",
       "          [-1.7383e-01, -2.5977e-01,  4.9414e-01,  ...,  1.4355e-01,\n",
       "           -1.3184e-01,  9.9609e-02],\n",
       "          ...,\n",
       "          [-8.5156e-01, -1.2354e-01,  7.4707e-02,  ...,  4.8242e-01,\n",
       "            4.0820e-01,  4.3164e-01],\n",
       "          [-3.8477e-01, -9.5703e-02,  3.0664e-01,  ...,  3.9648e-01,\n",
       "            3.1836e-01,  5.3906e-01],\n",
       "          [-3.6133e-01,  1.3281e-01,  2.0312e-01,  ...,  1.0840e-01,\n",
       "            1.4893e-02,  1.6895e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 1.3977e-02,  5.7983e-03, -6.4392e-03,  ..., -1.5234e-01,\n",
       "            3.3203e-01,  5.4297e-01],\n",
       "          [-5.3906e-01,  9.3750e-02, -1.6699e-01,  ..., -2.8320e-01,\n",
       "           -1.3828e+00, -1.9629e-01],\n",
       "          [-1.5820e-01,  4.1016e-01,  1.5234e-01,  ..., -4.4336e-01,\n",
       "           -1.2793e-01,  4.0234e-01],\n",
       "          ...,\n",
       "          [ 8.3984e-02, -4.1406e-01,  1.5625e-02,  ..., -2.9492e-01,\n",
       "           -5.5859e-01, -1.5000e+00],\n",
       "          [-8.1543e-02, -1.9336e-01,  8.3496e-02,  ...,  5.2344e-01,\n",
       "           -7.5781e-01, -1.3438e+00],\n",
       "          [-3.5742e-01, -2.2754e-01,  1.9922e-01,  ..., -2.9144e-03,\n",
       "           -8.0078e-01, -2.0000e+00]],\n",
       "\n",
       "         [[ 1.3123e-02, -1.6479e-02,  2.7588e-02,  ..., -1.1523e-01,\n",
       "           -1.3867e-01,  4.7461e-01],\n",
       "          [ 1.0303e-01, -1.4746e-01, -6.9922e-01,  ...,  2.8125e-01,\n",
       "           -7.8516e-01, -4.6289e-01],\n",
       "          [-2.9297e-01,  1.2500e-01,  7.8613e-02,  ..., -1.8555e-01,\n",
       "            5.9375e-01, -8.2031e-01],\n",
       "          ...,\n",
       "          [ 2.5195e-01, -7.4219e-01,  7.3438e-01,  ..., -7.8906e-01,\n",
       "           -1.2109e-01,  2.9688e+00],\n",
       "          [-4.3359e-01, -7.0312e-01,  2.8516e-01,  ..., -2.3730e-01,\n",
       "            1.9043e-01,  1.5469e+00],\n",
       "          [-7.5391e-01,  8.3984e-02, -2.6562e-01,  ...,  4.8242e-01,\n",
       "            5.4297e-01,  8.0859e-01]],\n",
       "\n",
       "         [[ 1.7578e-02,  5.4932e-03, -3.3691e-02,  ..., -1.3672e-01,\n",
       "            3.1641e-01,  2.6953e-01],\n",
       "          [ 6.8750e-01,  3.0078e-01, -1.1914e-01,  ...,  7.2656e-01,\n",
       "           -7.7734e-01, -5.8203e-01],\n",
       "          [-3.4766e-01,  5.8203e-01, -1.2354e-01,  ..., -5.3516e-01,\n",
       "           -1.0312e+00, -9.0332e-02],\n",
       "          ...,\n",
       "          [ 4.4531e-01,  5.7031e-01,  2.1680e-01,  ..., -1.9727e-01,\n",
       "           -1.3281e+00, -1.3477e-01],\n",
       "          [-7.2266e-02,  9.8633e-02,  6.9531e-01,  ..., -1.4844e-01,\n",
       "           -2.4170e-02,  1.3672e+00],\n",
       "          [-3.0859e-01,  2.8711e-01,  8.5938e-01,  ..., -3.3398e-01,\n",
       "           -7.4609e-01,  1.9375e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 3.7994e-03,  2.1851e-02,  1.4465e-02,  ...,  1.3184e-01,\n",
       "           -2.0996e-01, -1.4941e-01],\n",
       "          [-1.6406e-01, -8.4961e-02, -3.7891e-01,  ..., -1.5991e-02,\n",
       "           -2.3438e-02,  7.0703e-01],\n",
       "          [ 1.6602e-01,  3.0469e-01, -1.5723e-01,  ...,  1.5547e+00,\n",
       "            1.0859e+00, -1.6022e-04],\n",
       "          ...,\n",
       "          [-5.3125e-01, -3.6523e-01, -3.3594e-01,  ...,  1.1953e+00,\n",
       "           -5.8203e-01, -6.7969e-01],\n",
       "          [-9.1406e-01, -3.8672e-01,  3.3984e-01,  ...,  1.3828e+00,\n",
       "            1.4062e-01, -4.1406e-01],\n",
       "          [ 1.2109e-01, -6.1328e-01, -2.7734e-01,  ...,  1.0986e-02,\n",
       "            1.1621e-01,  1.4453e+00]],\n",
       "\n",
       "         [[-6.1646e-03,  2.4872e-03, -2.7618e-03,  ...,  1.8066e-02,\n",
       "           -1.2256e-01, -3.5352e-01],\n",
       "          [-1.3672e-01, -1.3574e-01, -2.4609e-01,  ...,  6.2891e-01,\n",
       "           -6.5625e-01,  6.2891e-01],\n",
       "          [ 2.6758e-01,  4.2578e-01, -9.4727e-02,  ...,  9.2969e-01,\n",
       "            1.1406e+00,  3.6133e-01],\n",
       "          ...,\n",
       "          [ 4.9805e-02,  2.3145e-01,  4.0820e-01,  ...,  1.0625e+00,\n",
       "            1.3359e+00,  7.1484e-01],\n",
       "          [ 1.5332e-01,  2.5977e-01,  3.3203e-01,  ...,  7.4609e-01,\n",
       "            1.2500e+00,  2.6094e+00],\n",
       "          [-6.4453e-02,  6.5625e-01,  4.7266e-01,  ...,  7.8516e-01,\n",
       "           -5.1758e-02,  3.3594e+00]],\n",
       "\n",
       "         [[-2.1118e-02,  2.1118e-02,  1.5625e-02,  ...,  1.6562e+00,\n",
       "           -1.3574e-01, -2.0264e-02],\n",
       "          [ 2.5000e-01, -1.1035e-01,  5.0391e-01,  ..., -3.7969e+00,\n",
       "           -2.4512e-01, -1.2266e+00],\n",
       "          [ 1.3965e-01,  6.2988e-02,  2.0898e-01,  ..., -5.4375e+00,\n",
       "            1.6094e+00, -2.7812e+00],\n",
       "          ...,\n",
       "          [ 4.1992e-01, -1.2061e-01, -2.2266e-01,  ..., -4.6562e+00,\n",
       "           -5.0781e-01, -1.2812e+00],\n",
       "          [ 4.1406e-01, -4.0430e-01,  2.8906e-01,  ..., -3.6406e+00,\n",
       "           -8.9453e-01, -2.2188e+00],\n",
       "          [-6.9336e-02, -2.9492e-01, -2.1875e-01,  ..., -3.5312e+00,\n",
       "           -2.2031e+00,  1.1279e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 1.0681e-02,  4.0283e-03,  7.9956e-03,  ...,  1.1047e-02,\n",
       "            3.2501e-03,  5.3101e-03],\n",
       "          [ 4.0283e-02,  2.2705e-02,  9.3750e-02,  ..., -6.5625e-01,\n",
       "            7.4707e-02, -2.5781e-01],\n",
       "          [-2.3828e-01,  1.9629e-01,  2.2070e-01,  ...,  3.8086e-02,\n",
       "            2.7344e-01, -8.2812e-01],\n",
       "          ...,\n",
       "          [-1.3867e-01, -1.3086e-01, -1.2305e-01,  ...,  2.4316e-01,\n",
       "            3.3447e-02, -4.8828e-01],\n",
       "          [-2.0312e-01,  4.4531e-01,  3.2031e-01,  ...,  5.1172e-01,\n",
       "            3.1836e-01, -6.1719e-01],\n",
       "          [-3.3398e-01, -2.1680e-01,  6.2109e-01,  ...,  1.8848e-01,\n",
       "           -4.8828e-01, -2.4121e-01]],\n",
       "\n",
       "         [[-1.9287e-02,  2.5635e-02,  1.2268e-02,  ..., -1.2891e-01,\n",
       "            1.7166e-03,  3.0518e-03],\n",
       "          [ 2.7539e-01, -5.8838e-02,  1.5820e-01,  ..., -4.6875e-01,\n",
       "           -1.7480e-01,  1.2598e-01],\n",
       "          [-5.7861e-02, -2.1094e-01,  3.6865e-02,  ..., -5.6641e-01,\n",
       "           -1.9336e-01,  2.6245e-02],\n",
       "          ...,\n",
       "          [-9.2773e-02,  3.7109e-01, -3.0396e-02,  ...,  1.3281e+00,\n",
       "            1.5039e-01,  2.2852e-01],\n",
       "          [ 4.4336e-01,  2.4316e-01, -1.3867e-01,  ...,  8.5547e-01,\n",
       "            3.8867e-01,  4.3750e-01],\n",
       "          [ 1.3245e-02, -2.1729e-02,  2.8906e-01,  ...,  1.0781e+00,\n",
       "            2.8564e-02, -1.0498e-01]],\n",
       "\n",
       "         [[-1.0010e-02, -5.5237e-03, -1.1063e-03,  ...,  1.5076e-02,\n",
       "            1.3977e-02, -1.8188e-02],\n",
       "          [-1.3965e-01,  1.4746e-01, -1.6211e-01,  ..., -6.8750e-01,\n",
       "           -3.0664e-01,  3.1445e-01],\n",
       "          [ 4.6094e-01,  4.0039e-01, -1.2500e-01,  ..., -3.6719e-01,\n",
       "           -3.6133e-01, -8.8867e-02],\n",
       "          ...,\n",
       "          [-1.7773e-01,  9.0820e-02,  3.4570e-01,  ..., -2.5195e-01,\n",
       "            1.3965e-01,  4.3555e-01],\n",
       "          [-2.1777e-01, -3.4570e-01, -1.0547e-01,  ..., -1.0742e-01,\n",
       "           -7.8125e-03,  3.3722e-03],\n",
       "          [ 2.5977e-01, -3.9258e-01, -2.0264e-02,  ...,  2.6953e-01,\n",
       "            3.6523e-01, -1.2146e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 3.5858e-03, -2.7832e-02,  2.1729e-02,  ...,  2.7344e-02,\n",
       "           -5.6152e-03, -7.0496e-03],\n",
       "          [-9.7046e-03,  1.2695e-01,  2.0801e-01,  ...,  2.7734e-01,\n",
       "            3.1055e-01,  4.2188e-01],\n",
       "          [-1.3574e-01,  4.8096e-02,  8.7891e-01,  ..., -1.0010e-01,\n",
       "            1.7773e-01, -2.0801e-01],\n",
       "          ...,\n",
       "          [ 1.5332e-01, -2.9688e-01,  1.0449e-01,  ..., -2.0020e-01,\n",
       "           -3.1445e-01, -7.9956e-03],\n",
       "          [ 9.7168e-02, -1.2793e-01, -1.0791e-01,  ..., -3.7305e-01,\n",
       "           -1.8457e-01, -6.5918e-02],\n",
       "          [-3.9551e-02,  8.0566e-03, -1.6406e-01,  ...,  1.6113e-01,\n",
       "           -5.8203e-01,  3.0273e-01]],\n",
       "\n",
       "         [[-1.5488e-03, -1.1826e-03,  1.0498e-02,  ...,  2.3682e-02,\n",
       "           -1.5137e-02, -1.0498e-02],\n",
       "          [-9.2773e-02,  7.5391e-01, -1.3962e-03,  ...,  1.6357e-02,\n",
       "           -5.4297e-01, -8.5449e-02],\n",
       "          [ 3.5547e-01,  1.1172e+00, -2.9883e-01,  ...,  9.7168e-02,\n",
       "            1.9434e-01, -3.1445e-01],\n",
       "          ...,\n",
       "          [-5.9814e-02, -2.0996e-01, -8.8867e-02,  ...,  2.7539e-01,\n",
       "           -3.5645e-02,  2.7930e-01],\n",
       "          [ 8.2397e-03, -1.2158e-01, -1.2598e-01,  ..., -3.1006e-02,\n",
       "            3.1494e-02,  1.3477e-01],\n",
       "          [ 3.3398e-01,  1.7676e-01,  2.7148e-01,  ..., -7.3828e-01,\n",
       "           -2.5586e-01, -1.8652e-01]],\n",
       "\n",
       "         [[ 5.0049e-03,  7.4158e-03,  5.5695e-04,  ...,  5.5847e-03,\n",
       "           -9.3994e-03,  8.7891e-03],\n",
       "          [ 2.8711e-01, -2.0996e-01, -1.9141e-01,  ..., -8.2520e-02,\n",
       "            4.0283e-02, -6.1523e-02],\n",
       "          [-2.8711e-01,  2.0605e-01, -5.2734e-02,  ..., -2.0752e-02,\n",
       "            1.4355e-01,  1.5332e-01],\n",
       "          ...,\n",
       "          [-3.5156e-01,  1.2695e-01, -7.2266e-01,  ..., -1.2451e-01,\n",
       "           -9.6191e-02,  1.0449e-01],\n",
       "          [ 1.3977e-02,  1.9043e-01, -2.2559e-01,  ...,  2.4707e-01,\n",
       "           -2.0117e-01, -6.8359e-02],\n",
       "          [ 3.7305e-01,  4.0771e-02, -3.6133e-01,  ..., -2.7148e-01,\n",
       "           -5.7812e-01,  4.2578e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 1.9165e-02, -4.3030e-03,  4.9744e-03,  ..., -1.0400e-01,\n",
       "            4.5117e-01,  1.6719e+00],\n",
       "          [-5.1172e-01,  2.6562e-01, -1.0400e-01,  ...,  2.5000e+00,\n",
       "            7.2021e-03, -6.5000e+00],\n",
       "          [-6.8359e-01,  4.1211e-01, -5.6641e-02,  ...,  2.2969e+00,\n",
       "           -1.5312e+00, -7.4062e+00],\n",
       "          ...,\n",
       "          [ 6.4453e-01, -7.9688e-01, -5.1514e-02,  ...,  4.0625e+00,\n",
       "            1.9922e+00, -5.6562e+00],\n",
       "          [-5.7617e-02, -2.9883e-01, -1.2695e-01,  ...,  1.4922e+00,\n",
       "            5.8984e-01, -7.3125e+00],\n",
       "          [-5.8984e-01, -2.7148e-01, -5.0781e-01,  ...,  5.2734e-01,\n",
       "            1.1406e+00, -5.9375e+00]],\n",
       "\n",
       "         [[-1.8311e-02,  1.3062e-02, -9.5215e-03,  ..., -3.4766e-01,\n",
       "            2.1484e-01,  2.1387e-01],\n",
       "          [ 3.0859e-01,  2.4805e-01, -1.9727e-01,  ..., -2.5391e-01,\n",
       "            6.0938e-01,  8.9453e-01],\n",
       "          [ 2.0801e-01,  1.1621e-01,  1.0156e+00,  ..., -3.1836e-01,\n",
       "           -2.2344e+00,  6.6797e-01],\n",
       "          ...,\n",
       "          [ 7.5391e-01,  3.6328e-01,  1.5000e+00,  ...,  2.8711e-01,\n",
       "           -4.5898e-01, -1.3516e+00],\n",
       "          [ 4.5312e-01,  1.5234e-01,  6.2109e-01,  ...,  7.8516e-01,\n",
       "            1.9824e-01, -4.3750e-01],\n",
       "          [ 3.0664e-01, -3.5742e-01,  4.2578e-01,  ...,  5.3516e-01,\n",
       "           -1.4531e+00, -8.3594e-01]],\n",
       "\n",
       "         [[ 1.5320e-02,  2.0630e-02, -2.2339e-02,  ..., -7.9688e-01,\n",
       "            3.7695e-01,  3.9844e-01],\n",
       "          [-9.2188e-01, -3.9844e-01, -1.0791e-01,  ...,  2.3281e+00,\n",
       "           -2.2559e-01, -2.2031e+00],\n",
       "          [-5.3906e-01, -5.8203e-01,  5.5078e-01,  ...,  1.0469e+00,\n",
       "           -1.3984e+00, -3.6406e+00],\n",
       "          ...,\n",
       "          [ 2.6953e-01,  1.2500e+00, -1.2578e+00,  ...,  2.3145e-01,\n",
       "           -1.1875e+00, -2.9844e+00],\n",
       "          [ 5.1953e-01, -1.9824e-01, -8.3984e-01,  ...,  1.0469e+00,\n",
       "           -7.2266e-01, -1.0156e+00],\n",
       "          [-1.0596e-01,  7.6953e-01,  7.3047e-01,  ...,  2.2339e-02,\n",
       "           -3.2344e+00, -1.4609e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 8.1177e-03, -1.3000e-02, -8.4229e-03,  ..., -2.2583e-02,\n",
       "            2.2949e-01, -1.1641e+00],\n",
       "          [ 1.6797e-01, -1.5820e-01,  1.0352e-01,  ...,  7.5391e-01,\n",
       "            1.5547e+00,  2.3906e+00],\n",
       "          [-3.0469e-01,  6.1328e-01,  5.1172e-01,  ...,  1.0437e-02,\n",
       "            1.4922e+00,  2.5312e+00],\n",
       "          ...,\n",
       "          [-1.4648e-03, -5.9766e-01, -1.3574e-01,  ...,  7.8516e-01,\n",
       "            9.6680e-02,  4.2188e+00],\n",
       "          [ 2.3828e-01,  6.8054e-03, -5.5469e-01,  ...,  1.9062e+00,\n",
       "           -1.7500e+00,  4.3125e+00],\n",
       "          [ 1.9531e-01, -6.3672e-01,  4.5312e-01,  ...,  2.2031e+00,\n",
       "           -1.7188e+00,  4.4062e+00]],\n",
       "\n",
       "         [[-4.1260e-02, -2.7222e-02, -7.7248e-05,  ...,  4.3555e-01,\n",
       "            4.0234e-01, -1.7578e-01],\n",
       "          [ 3.3984e-01,  3.9258e-01,  1.0742e-01,  ..., -1.0000e+00,\n",
       "            1.2500e+00,  6.7578e-01],\n",
       "          [-7.5195e-02,  1.2578e+00,  5.7422e-01,  ..., -3.1406e+00,\n",
       "            1.7031e+00, -3.4766e-01],\n",
       "          ...,\n",
       "          [-3.2617e-01,  7.7734e-01,  5.8594e-02,  ...,  3.4531e+00,\n",
       "            8.6914e-02,  1.7656e+00],\n",
       "          [-6.6406e-01,  3.0078e-01, -9.2188e-01,  ...,  2.4844e+00,\n",
       "           -1.7773e-01,  4.8633e-01],\n",
       "          [ 5.8984e-01,  3.7305e-01, -5.1562e-01,  ...,  4.7188e+00,\n",
       "           -1.7969e+00,  9.1797e-02]],\n",
       "\n",
       "         [[-6.5002e-03,  1.3489e-02,  6.8283e-04,  ..., -7.2266e-01,\n",
       "           -4.4922e-01,  1.1182e-01],\n",
       "          [ 5.3125e-01, -2.7344e-01,  3.1055e-01,  ..., -7.7734e-01,\n",
       "            1.6641e+00, -3.8867e-01],\n",
       "          [ 3.1836e-01,  1.6602e-01,  1.8945e-01,  ...,  2.4805e-01,\n",
       "            1.6094e+00, -1.7422e+00],\n",
       "          ...,\n",
       "          [ 2.0020e-02,  1.8262e-01, -1.0059e-01,  ..., -6.8359e-02,\n",
       "            1.6406e+00, -7.7734e-01],\n",
       "          [ 1.9238e-01,  3.9062e-03, -1.2305e-01,  ...,  3.8477e-01,\n",
       "           -8.2812e-01, -1.4922e+00],\n",
       "          [ 1.8359e-01, -4.5703e-01, -4.6289e-01,  ..., -3.1055e-01,\n",
       "            8.6328e-01, -5.8984e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 3.4637e-03, -1.0803e-02, -3.2196e-03,  ...,  4.7913e-03,\n",
       "           -4.5166e-03, -3.9978e-03],\n",
       "          [-1.6895e-01,  7.5781e-01,  6.5234e-01,  ...,  1.6699e-01,\n",
       "            7.1875e-01,  9.0625e-01],\n",
       "          [-3.3789e-01,  2.5977e-01,  5.6250e-01,  ..., -3.0640e-02,\n",
       "            4.5117e-01,  6.6797e-01],\n",
       "          ...,\n",
       "          [ 6.7969e-01, -6.5918e-02,  6.2891e-01,  ..., -2.1191e-01,\n",
       "            1.0078e+00,  2.4658e-02],\n",
       "          [ 2.8906e-01, -1.9043e-01,  1.4844e-01,  ...,  9.3750e-02,\n",
       "            9.4531e-01,  1.8750e-01],\n",
       "          [ 5.8594e-01, -6.6016e-01, -2.7344e-01,  ...,  1.7676e-01,\n",
       "            6.0547e-01, -1.0400e-01]],\n",
       "\n",
       "         [[ 1.1902e-02,  5.5908e-02,  1.4771e-02,  ...,  4.0283e-02,\n",
       "           -4.5654e-02, -2.1484e-02],\n",
       "          [-5.3906e-01, -1.0234e+00, -3.9453e-01,  ..., -2.7539e-01,\n",
       "           -2.4316e-01, -1.0254e-01],\n",
       "          [-1.6699e-01, -7.5391e-01, -5.5859e-01,  ..., -6.6016e-01,\n",
       "           -3.8477e-01,  3.5156e-02],\n",
       "          ...,\n",
       "          [-3.7109e-01, -3.3008e-01, -2.9883e-01,  ..., -1.0547e-01,\n",
       "           -2.8711e-01,  4.4141e-01],\n",
       "          [-2.5781e-01, -2.5977e-01,  7.7637e-02,  ...,  5.1172e-01,\n",
       "           -2.9883e-01,  8.3984e-02],\n",
       "          [-3.9062e-01,  1.0352e-01,  3.7305e-01,  ...,  2.3145e-01,\n",
       "            2.1680e-01,  1.9922e-01]],\n",
       "\n",
       "         [[-4.4861e-03,  2.0630e-02, -2.0386e-02,  ...,  4.5166e-03,\n",
       "            1.0071e-02,  5.8289e-03],\n",
       "          [-4.8828e-02, -1.0596e-01,  9.1797e-02,  ..., -1.7188e-01,\n",
       "           -2.1094e-01,  1.6113e-01],\n",
       "          [-2.5977e-01, -9.0332e-03, -7.3730e-02,  ..., -5.1514e-02,\n",
       "           -4.7852e-01,  1.4551e-01],\n",
       "          ...,\n",
       "          [-1.4551e-01,  3.4375e-01,  3.0273e-01,  ..., -3.0469e-01,\n",
       "           -8.2031e-01, -4.8242e-01],\n",
       "          [-1.7188e-01,  2.9297e-02,  2.4219e-01,  ..., -1.7188e-01,\n",
       "           -4.3750e-01, -8.7402e-02],\n",
       "          [ 1.3574e-01,  3.7305e-01, -4.5654e-02,  ..., -1.3245e-02,\n",
       "           -8.8672e-01,  3.2031e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 8.9722e-03,  8.1177e-03,  1.3123e-02,  ...,  1.3245e-02,\n",
       "           -6.2256e-03,  1.1169e-02],\n",
       "          [-8.6426e-02, -2.7344e-01, -8.3008e-02,  ...,  2.5586e-01,\n",
       "           -2.4023e-01, -2.4707e-01],\n",
       "          [ 9.1797e-02, -6.4453e-01, -1.7090e-01,  ...,  2.4121e-01,\n",
       "           -3.1445e-01, -4.2773e-01],\n",
       "          ...,\n",
       "          [-5.3711e-02, -1.1016e+00, -3.5938e-01,  ..., -6.5430e-02,\n",
       "            4.1504e-02,  2.8320e-01],\n",
       "          [ 4.2236e-02, -5.5469e-01, -3.3984e-01,  ..., -1.7090e-02,\n",
       "           -3.1836e-01,  5.7812e-01],\n",
       "          [ 2.2852e-01, -3.5938e-01,  2.8320e-01,  ..., -5.9814e-02,\n",
       "           -1.1084e-01,  4.4727e-01]],\n",
       "\n",
       "         [[-7.8125e-03,  3.3112e-03,  1.0376e-02,  ..., -8.4839e-03,\n",
       "           -1.5625e-02, -1.2112e-04],\n",
       "          [ 1.7548e-03, -4.0625e-01,  1.9727e-01,  ...,  2.0996e-02,\n",
       "            2.4292e-02,  2.2266e-01],\n",
       "          [-1.8164e-01, -2.3047e-01, -3.3594e-01,  ..., -2.2949e-01,\n",
       "            2.5391e-01, -7.8125e-02],\n",
       "          ...,\n",
       "          [-3.8281e-01, -2.8125e-01, -8.2520e-02,  ..., -9.2773e-02,\n",
       "           -1.0938e-01,  1.9238e-01],\n",
       "          [-7.7637e-02, -1.1230e-01, -3.1055e-01,  ...,  6.3965e-02,\n",
       "           -2.3926e-01,  1.6895e-01],\n",
       "          [-8.9844e-02,  1.9043e-01, -2.3145e-01,  ...,  6.0059e-02,\n",
       "           -2.4414e-01, -1.3379e-01]],\n",
       "\n",
       "         [[ 5.6152e-03, -6.9885e-03, -2.1362e-02,  ..., -1.8539e-03,\n",
       "           -3.0060e-03,  4.6997e-03],\n",
       "          [ 4.1260e-02,  1.8652e-01,  2.6953e-01,  ...,  4.8438e-01,\n",
       "            7.1094e-01, -1.1865e-01],\n",
       "          [ 1.7383e-01,  7.1875e-01, -2.8931e-02,  ...,  2.5195e-01,\n",
       "            5.1172e-01,  1.9922e-01],\n",
       "          ...,\n",
       "          [ 2.0898e-01,  3.5938e-01,  4.2725e-02,  ..., -1.5039e-01,\n",
       "            1.2109e-01, -2.2754e-01],\n",
       "          [ 2.9688e-01,  4.8438e-01,  1.5137e-02,  ..., -2.9175e-02,\n",
       "            3.4570e-01,  2.8198e-02],\n",
       "          [ 4.3945e-01,  6.2891e-01,  8.8379e-02,  ...,  4.9023e-01,\n",
       "            6.6406e-01, -1.4551e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-1.4221e-02, -7.5684e-03,  1.8188e-02,  ...,  1.3203e+00,\n",
       "            4.0039e-02,  3.9844e-01],\n",
       "          [-6.3477e-03, -2.8516e-01, -2.4902e-01,  ..., -3.2422e-01,\n",
       "           -2.0938e+00, -6.8750e-01],\n",
       "          [-1.2207e-01,  4.3701e-02,  2.5146e-02,  ..., -1.8594e+00,\n",
       "           -3.2656e+00,  1.1182e-01],\n",
       "          ...,\n",
       "          [ 7.5000e-01, -3.2031e-01,  6.3281e-01,  ..., -1.7578e+00,\n",
       "           -1.6641e+00, -2.3281e+00],\n",
       "          [ 1.8203e+00,  8.7891e-03,  3.7305e-01,  ..., -3.2656e+00,\n",
       "           -1.0078e+00, -2.0156e+00],\n",
       "          [ 7.0312e-02,  1.1816e-01, -2.8711e-01,  ..., -1.1875e+00,\n",
       "           -8.3594e-01, -6.2500e-01]],\n",
       "\n",
       "         [[-1.1841e-02, -1.3672e-02,  3.0670e-03,  ..., -2.4902e-01,\n",
       "           -1.6113e-01, -1.3770e-01],\n",
       "          [-1.6797e-01, -8.5938e-02,  7.3047e-01,  ..., -2.7148e-01,\n",
       "            2.0898e-01, -2.0312e+00],\n",
       "          [-1.8594e+00,  3.3008e-01, -8.2031e-02,  ...,  5.8203e-01,\n",
       "            2.2031e+00, -1.2812e+00],\n",
       "          ...,\n",
       "          [ 1.3984e+00, -8.0469e-01, -5.8594e-03,  ...,  4.1602e-01,\n",
       "            3.9062e-01,  4.2578e-01],\n",
       "          [ 1.0938e+00, -3.4180e-01,  5.0781e-02,  ...,  1.3281e+00,\n",
       "            1.6719e+00,  1.7500e+00],\n",
       "          [-2.3828e-01,  2.1484e-01,  1.2500e-01,  ..., -1.1484e+00,\n",
       "            1.0859e+00,  3.5889e-02]],\n",
       "\n",
       "         [[-1.7578e-02, -1.1047e-02, -2.7618e-03,  ...,  2.9492e-01,\n",
       "            7.2266e-02, -6.2109e-01],\n",
       "          [-3.4766e-01, -3.5938e-01, -9.9121e-02,  ..., -2.7148e-01,\n",
       "           -1.0312e+00, -1.4766e+00],\n",
       "          [-3.3203e-02,  2.3828e-01,  1.3184e-02,  ..., -5.0781e-01,\n",
       "           -1.7500e+00, -1.5391e+00],\n",
       "          ...,\n",
       "          [ 1.2500e-01,  2.8711e-01,  3.4424e-02,  ...,  3.3789e-01,\n",
       "           -1.1797e+00, -3.7891e-01],\n",
       "          [ 4.4922e-02, -8.7500e-01, -9.4727e-02,  ..., -7.4219e-01,\n",
       "           -6.9922e-01,  8.0078e-01],\n",
       "          [-1.4453e-01, -3.3203e-01, -4.1211e-01,  ..., -2.2500e+00,\n",
       "           -2.4219e+00, -4.4141e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-9.4604e-03,  1.2329e-02,  2.3041e-03,  ...,  1.8555e-01,\n",
       "            2.3906e+00, -1.8921e-02],\n",
       "          [ 2.7148e-01, -2.4609e-01,  2.5391e-02,  ..., -1.7109e+00,\n",
       "           -1.6016e+00, -3.4570e-01],\n",
       "          [ 3.1641e-01, -3.5352e-01, -3.1641e-01,  ..., -2.9219e+00,\n",
       "           -4.5625e+00, -6.0156e-01],\n",
       "          ...,\n",
       "          [-1.1426e-01, -4.9219e-01,  4.1406e-01,  ...,  5.6641e-01,\n",
       "           -3.4062e+00, -2.2188e+00],\n",
       "          [-6.9141e-01,  9.5703e-02,  3.2031e-01,  ...,  7.5195e-02,\n",
       "           -4.6250e+00, -2.8125e+00],\n",
       "          [ 3.0859e-01,  2.7539e-01,  1.0312e+00,  ..., -1.5156e+00,\n",
       "           -4.0938e+00, -3.2656e+00]],\n",
       "\n",
       "         [[ 3.1494e-02,  7.6599e-03,  2.8076e-03,  ..., -1.0078e+00,\n",
       "            1.1016e+00, -8.3984e-01],\n",
       "          [-8.2422e-01, -1.8848e-01, -2.3438e-01,  ..., -1.6328e+00,\n",
       "            2.5469e+00, -1.6113e-01],\n",
       "          [-3.5742e-01,  2.5586e-01, -1.7822e-02,  ..., -6.9141e-01,\n",
       "            8.7891e-02, -2.6367e-01],\n",
       "          ...,\n",
       "          [ 2.1680e-01, -6.6016e-01,  4.1992e-01,  ...,  1.6406e+00,\n",
       "            1.1719e+00, -1.0938e+00],\n",
       "          [-3.8086e-01,  3.0273e-01,  1.0234e+00,  ...,  2.9219e+00,\n",
       "           -4.1406e-01, -9.2969e-01],\n",
       "          [ 4.7070e-01, -7.2656e-01, -1.4160e-01,  ...,  2.8438e+00,\n",
       "            6.6016e-01,  9.6094e-01]],\n",
       "\n",
       "         [[ 3.9795e-02, -5.9326e-02, -1.4587e-02,  ..., -1.2988e-01,\n",
       "           -2.1094e-01, -5.3223e-02],\n",
       "          [ 6.1328e-01, -1.1523e-01,  2.4316e-01,  ...,  3.0625e+00,\n",
       "            2.1484e-01,  1.0781e+00],\n",
       "          [-4.6143e-02, -2.6758e-01,  1.4038e-02,  ...,  1.4219e+00,\n",
       "            2.4062e+00, -2.5195e-01],\n",
       "          ...,\n",
       "          [-2.4316e-01,  3.3203e-01,  3.0078e-01,  ...,  1.4141e+00,\n",
       "            2.7812e+00,  9.3750e-01],\n",
       "          [ 1.6992e-01,  3.1641e-01,  5.3516e-01,  ...,  1.8047e+00,\n",
       "            1.5391e+00,  1.7344e+00],\n",
       "          [-3.8086e-01, -1.7285e-01,  2.3340e-01,  ...,  1.8984e+00,\n",
       "            1.9688e+00,  6.4062e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 9.6436e-03,  2.2095e-02, -6.7444e-03,  ..., -3.4142e-04,\n",
       "            3.1982e-02, -9.9487e-03],\n",
       "          [-3.9648e-01, -7.8125e-02,  5.3125e-01,  ...,  1.1658e-02,\n",
       "            2.6367e-01,  7.0703e-01],\n",
       "          [ 6.8054e-03,  3.8086e-02,  7.6953e-01,  ..., -9.9609e-02,\n",
       "            2.8516e-01,  3.3008e-01],\n",
       "          ...,\n",
       "          [ 1.0596e-01,  5.6396e-02, -7.0703e-01,  ...,  2.7734e-01,\n",
       "            9.2188e-01,  8.6426e-02],\n",
       "          [-5.7617e-02, -6.9824e-02, -6.4941e-02,  ..., -1.4941e-01,\n",
       "            8.4473e-02, -3.4424e-02],\n",
       "          [-2.0312e-01, -1.1182e-01,  4.9805e-01,  ..., -1.9434e-01,\n",
       "           -8.1055e-02, -1.5137e-01]],\n",
       "\n",
       "         [[-9.7656e-02, -8.5449e-02,  6.7383e-02,  ...,  1.7944e-02,\n",
       "            2.4292e-02, -7.5195e-02],\n",
       "          [ 1.4746e-01, -3.6719e-01, -3.8281e-01,  ..., -1.6504e-01,\n",
       "           -3.8330e-02, -1.1230e-01],\n",
       "          [ 2.0801e-01, -2.0898e-01, -9.3750e-02,  ...,  1.3281e-01,\n",
       "           -9.5703e-02,  3.5352e-01],\n",
       "          ...,\n",
       "          [-1.3672e-01,  1.7383e-01,  8.6426e-02,  ...,  3.5742e-01,\n",
       "           -2.3633e-01,  9.2773e-02],\n",
       "          [ 1.0010e-01, -8.0490e-04, -4.9219e-01,  ...,  2.9883e-01,\n",
       "           -1.0156e-01,  3.3008e-01],\n",
       "          [ 3.4180e-02,  7.3242e-02,  1.0376e-02,  ...,  5.0391e-01,\n",
       "           -1.0303e-01,  7.6660e-02]],\n",
       "\n",
       "         [[ 6.9824e-02,  5.0537e-02, -4.8828e-02,  ...,  3.5889e-02,\n",
       "            2.4780e-02,  2.4536e-02],\n",
       "          [-3.9648e-01, -5.1172e-01,  4.4531e-01,  ...,  3.1641e-01,\n",
       "            3.5742e-01, -1.7773e-01],\n",
       "          [-5.5664e-02, -8.4375e-01,  1.8848e-01,  ..., -5.4297e-01,\n",
       "            9.7266e-01,  2.6562e-01],\n",
       "          ...,\n",
       "          [-1.6602e-01, -4.4678e-02,  4.1992e-01,  ..., -3.9453e-01,\n",
       "            2.1973e-01, -1.1094e+00],\n",
       "          [-7.4707e-02,  5.3711e-02,  8.7891e-01,  ..., -6.0425e-03,\n",
       "            3.5352e-01, -4.1406e-01],\n",
       "          [ 5.1953e-01, -4.8047e-01, -4.9219e-01,  ...,  4.7266e-01,\n",
       "            1.5527e-01, -7.8125e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 7.4463e-03, -3.3203e-02,  7.0953e-04,  ..., -4.0588e-03,\n",
       "            2.1484e-02, -2.2339e-02],\n",
       "          [ 2.0312e-01, -6.9922e-01,  2.8516e-01,  ...,  9.9609e-02,\n",
       "           -1.8945e-01,  2.0020e-01],\n",
       "          [ 2.3242e-01, -1.7090e-01,  4.9133e-03,  ..., -3.7109e-02,\n",
       "            2.5195e-01,  1.1279e-01],\n",
       "          ...,\n",
       "          [ 3.9062e-01, -5.6250e-01,  3.9453e-01,  ..., -3.6621e-02,\n",
       "           -3.2812e-01, -3.7891e-01],\n",
       "          [ 4.8438e-01, -4.0234e-01,  8.1787e-03,  ..., -9.1797e-02,\n",
       "           -2.2461e-01, -2.5977e-01],\n",
       "          [-1.9043e-01, -7.7344e-01,  2.6562e-01,  ..., -2.8906e-01,\n",
       "           -3.2422e-01,  2.0898e-01]],\n",
       "\n",
       "         [[-2.9297e-02,  2.2095e-02, -2.8687e-03,  ...,  1.9226e-03,\n",
       "           -8.8501e-03,  1.6724e-02],\n",
       "          [ 5.4321e-03,  5.1953e-01,  2.7344e-01,  ..., -1.2878e-02,\n",
       "           -2.0898e-01, -3.1836e-01],\n",
       "          [ 2.3828e-01,  3.4570e-01, -6.7749e-03,  ...,  5.6885e-02,\n",
       "            7.8613e-02, -8.8867e-02],\n",
       "          ...,\n",
       "          [ 1.7383e-01,  4.8047e-01,  8.1250e-01,  ...,  1.4160e-01,\n",
       "           -4.5508e-01,  1.8652e-01],\n",
       "          [-1.0547e-01,  2.9688e-01,  4.5898e-01,  ..., -3.1836e-01,\n",
       "           -2.6172e-01,  3.3594e-01],\n",
       "          [ 4.1406e-01,  1.0791e-01, -3.4766e-01,  ...,  2.0898e-01,\n",
       "           -1.1279e-01,  4.6289e-01]],\n",
       "\n",
       "         [[ 1.5182e-03,  7.0190e-03,  1.1108e-02,  ...,  1.7822e-02,\n",
       "           -1.1963e-02,  4.5586e-04],\n",
       "          [-2.6758e-01, -2.8125e-01,  7.1875e-01,  ...,  2.5586e-01,\n",
       "            1.1572e-01, -1.0303e-01],\n",
       "          [-2.8125e-01, -5.5859e-01,  4.6875e-01,  ..., -2.5977e-01,\n",
       "           -3.3594e-01, -2.9883e-01],\n",
       "          ...,\n",
       "          [-1.2891e-01, -3.2227e-01, -2.3535e-01,  ...,  2.0215e-01,\n",
       "            1.3086e-01,  4.5410e-02],\n",
       "          [-8.5938e-02, -5.5420e-02,  1.2598e-01,  ..., -1.4355e-01,\n",
       "            9.1309e-02,  5.6641e-01],\n",
       "          [-2.7148e-01, -2.9883e-01,  8.9844e-01,  ...,  3.1055e-01,\n",
       "            4.9219e-01,  4.5117e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 8.4229e-03, -2.1484e-02, -8.3923e-04,  ...,  8.2520e-02,\n",
       "           -1.3672e-01,  2.3926e-02],\n",
       "          [ 2.3828e-01, -2.2363e-01,  3.3203e-01,  ...,  5.7031e-01,\n",
       "           -6.3281e-01,  1.4453e-01],\n",
       "          [ 4.4336e-01, -5.0781e-01, -4.0527e-02,  ..., -9.1406e-01,\n",
       "           -1.6562e+00, -3.0664e-01],\n",
       "          ...,\n",
       "          [-2.5391e-01, -3.2031e-01,  1.5625e-01,  ...,  3.1128e-03,\n",
       "            2.3730e-01,  7.0312e-01],\n",
       "          [-4.9805e-01,  3.5547e-01, -3.0078e-01,  ..., -4.5166e-02,\n",
       "           -8.6060e-03, -2.8711e-01],\n",
       "          [ 6.2500e-01,  4.1016e-01,  8.1055e-02,  ..., -2.7734e-01,\n",
       "            8.1641e-01, -6.7383e-02]],\n",
       "\n",
       "         [[ 2.7222e-02,  2.0504e-04,  2.8442e-02,  ...,  2.4219e-01,\n",
       "            1.7969e-01,  2.4023e-01],\n",
       "          [ 4.4922e-01,  1.9141e-01, -1.4844e-01,  ...,  4.1992e-01,\n",
       "           -5.5078e-01, -4.9219e-01],\n",
       "          [ 8.5938e-01, -4.3750e-01, -4.4922e-01,  ..., -3.9062e-01,\n",
       "           -1.6094e+00, -5.6250e-01],\n",
       "          ...,\n",
       "          [-1.7773e-01, -2.1387e-01, -3.2617e-01,  ...,  1.2734e+00,\n",
       "            1.4404e-02, -2.5156e+00],\n",
       "          [-1.8359e-01, -8.4961e-02, -1.8457e-01,  ..., -1.8750e+00,\n",
       "           -1.2188e+00, -1.2188e+00],\n",
       "          [ 5.9766e-01,  9.9609e-01, -5.1562e-01,  ..., -1.3359e+00,\n",
       "           -9.6484e-01,  6.1328e-01]],\n",
       "\n",
       "         [[-2.1729e-02,  5.1880e-03,  1.7456e-02,  ..., -1.0234e+00,\n",
       "           -1.9336e-01, -2.0625e+00],\n",
       "          [-2.6172e-01,  1.7773e-01,  1.3281e-01,  ...,  2.1875e+00,\n",
       "           -1.7188e+00,  2.8438e+00],\n",
       "          [-3.2422e-01,  7.0801e-02,  6.9275e-03,  ...,  2.8750e+00,\n",
       "           -1.5156e+00,  4.1562e+00],\n",
       "          ...,\n",
       "          [ 1.3184e-01,  1.4844e-01,  3.3398e-01,  ...,  4.6875e+00,\n",
       "           -3.7031e+00,  3.9531e+00],\n",
       "          [ 3.3008e-01,  3.6914e-01, -2.6562e-01,  ...,  3.7500e+00,\n",
       "           -8.7891e-01,  5.8750e+00],\n",
       "          [ 3.1445e-01,  1.2109e-01,  7.4219e-02,  ...,  2.4062e+00,\n",
       "           -4.3750e-01,  5.4375e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-2.1362e-02,  1.6968e-02, -1.2451e-02,  ...,  1.9434e-01,\n",
       "            4.4678e-02,  1.2402e-01],\n",
       "          [-4.2188e-01, -4.7852e-01,  2.0752e-03,  ..., -1.5547e+00,\n",
       "           -1.8047e+00, -2.2656e+00],\n",
       "          [ 2.7344e-02, -5.1562e-01, -1.4258e-01,  ..., -1.7422e+00,\n",
       "           -2.5312e+00, -2.5312e+00],\n",
       "          ...,\n",
       "          [ 2.7588e-02, -5.1270e-03,  5.4688e-01,  ..., -2.0781e+00,\n",
       "           -1.6250e+00, -1.3828e+00],\n",
       "          [-1.2695e-01,  4.1211e-01, -3.7695e-01,  ..., -9.8828e-01,\n",
       "           -1.7734e+00, -1.9141e+00],\n",
       "          [-2.4902e-01,  3.5352e-01, -3.1738e-02,  ...,  2.7734e-01,\n",
       "           -1.4297e+00,  1.0625e+00]],\n",
       "\n",
       "         [[-1.5503e-02, -4.4861e-03,  3.5706e-03,  ...,  2.4512e-01,\n",
       "            2.6562e+00,  3.1982e-02],\n",
       "          [ 2.3438e-01,  1.8262e-01,  5.1172e-01,  ...,  3.7305e-01,\n",
       "           -3.1406e+00, -1.5430e-01],\n",
       "          [-4.6875e-01, -2.7734e-01,  5.9375e-01,  ..., -5.7983e-03,\n",
       "           -6.0312e+00,  2.6489e-02],\n",
       "          ...,\n",
       "          [ 4.2578e-01, -9.7656e-04,  3.6133e-01,  ..., -2.2500e+00,\n",
       "           -3.8594e+00, -9.3750e-01],\n",
       "          [ 2.3828e-01, -3.7500e-01,  4.5898e-01,  ..., -2.5781e+00,\n",
       "           -5.0000e+00, -1.5234e+00],\n",
       "          [-5.5078e-01, -4.1797e-01,  1.6406e-01,  ..., -2.2500e+00,\n",
       "           -3.4062e+00, -3.2969e+00]],\n",
       "\n",
       "         [[ 6.4087e-03,  6.9427e-04,  2.1484e-02,  ..., -1.5234e-01,\n",
       "           -1.3574e-01,  1.9824e-01],\n",
       "          [-1.6602e-01, -5.8594e-02, -6.5234e-01,  ..., -1.3359e+00,\n",
       "           -5.5859e-01, -7.2266e-01],\n",
       "          [ 1.8311e-03,  3.4961e-01, -5.1953e-01,  ..., -2.3125e+00,\n",
       "           -2.2031e+00, -4.1504e-02],\n",
       "          ...,\n",
       "          [-2.1289e-01,  5.3906e-01,  6.7578e-01,  ..., -2.2656e+00,\n",
       "           -1.1172e+00, -2.7656e+00],\n",
       "          [-1.4844e-01, -1.8457e-01,  1.6602e-01,  ..., -2.0625e+00,\n",
       "            5.0537e-02, -2.1562e+00],\n",
       "          [-3.5156e-02, -4.6094e-01,  6.4453e-02,  ..., -2.0781e+00,\n",
       "            1.1328e+00, -3.2969e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 1.6113e-02,  3.0762e-02,  1.4221e-02,  ..., -6.0425e-03,\n",
       "           -6.5308e-03, -1.5198e-02],\n",
       "          [-3.5352e-01, -4.2773e-01,  3.9648e-01,  ...,  5.0000e-01,\n",
       "            2.1362e-02,  2.1484e-01],\n",
       "          [ 1.8799e-02, -8.4375e-01,  1.5137e-01,  ...,  5.8350e-02,\n",
       "           -1.3965e-01, -5.3223e-02],\n",
       "          ...,\n",
       "          [-3.6523e-01, -5.7031e-01,  5.2246e-02,  ...,  3.8086e-01,\n",
       "           -9.4531e-01, -2.1240e-02],\n",
       "          [-1.1230e-02, -8.2031e-02,  1.2891e-01,  ...,  5.9766e-01,\n",
       "           -6.8750e-01, -4.4922e-01],\n",
       "          [-1.5918e-01,  7.2754e-02,  3.0469e-01,  ...,  2.5977e-01,\n",
       "            2.1582e-01, -2.0898e-01]],\n",
       "\n",
       "         [[-2.3193e-02, -5.8289e-03, -9.8877e-03,  ..., -1.5869e-02,\n",
       "            5.0735e-04,  2.5146e-02],\n",
       "          [ 2.9688e-01, -5.5176e-02, -2.2559e-01,  ..., -2.1362e-02,\n",
       "            3.9307e-02, -1.1523e-01],\n",
       "          [ 2.4609e-01,  1.1230e-01,  1.5527e-01,  ...,  1.5747e-02,\n",
       "            1.5918e-01,  2.8906e-01],\n",
       "          ...,\n",
       "          [ 3.7305e-01, -3.1250e-01, -2.8711e-01,  ...,  1.0449e-01,\n",
       "           -2.5000e-01, -1.9336e-01],\n",
       "          [ 1.6211e-01,  1.7969e-01,  2.0410e-01,  ...,  4.1260e-02,\n",
       "           -1.3867e-01,  1.3379e-01],\n",
       "          [-1.4551e-01,  6.2500e-01,  1.9141e-01,  ..., -3.9795e-02,\n",
       "            1.7773e-01, -6.7871e-02]],\n",
       "\n",
       "         [[-3.3264e-03, -1.8066e-02,  1.0803e-02,  ...,  7.6599e-03,\n",
       "           -5.2643e-04,  1.2159e-04],\n",
       "          [-4.7461e-01, -3.3008e-01,  4.3555e-01,  ...,  3.6523e-01,\n",
       "           -3.9551e-02, -4.7070e-01],\n",
       "          [-1.0107e-01,  6.5234e-01,  1.4062e-01,  ...,  3.7891e-01,\n",
       "           -4.0039e-02, -1.1719e-01],\n",
       "          ...,\n",
       "          [-6.0547e-01,  7.1289e-02,  2.9883e-01,  ..., -2.6758e-01,\n",
       "            9.2773e-02,  3.3936e-02],\n",
       "          [ 1.3281e+00, -2.1973e-01, -2.2656e-01,  ..., -2.5977e-01,\n",
       "           -5.5859e-01,  6.3281e-01],\n",
       "          [-9.1797e-02, -9.6680e-02, -1.1426e-01,  ...,  4.2773e-01,\n",
       "           -6.4941e-02,  8.6719e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.8799e-02, -2.9541e-02,  6.6528e-03,  ..., -5.1117e-04,\n",
       "            1.3855e-02, -3.3936e-02],\n",
       "          [-9.3750e-02,  2.1289e-01, -4.0039e-01,  ...,  4.1992e-02,\n",
       "            2.9297e-01,  3.1055e-01],\n",
       "          [-5.5176e-02,  5.6152e-02, -5.6885e-02,  ..., -2.6489e-02,\n",
       "            3.1250e-01, -1.1426e-01],\n",
       "          ...,\n",
       "          [ 5.3906e-01, -3.7891e-01,  1.2305e-01,  ..., -1.7773e-01,\n",
       "           -7.8613e-02, -2.4023e-01],\n",
       "          [ 5.2344e-01,  3.1738e-02, -3.4570e-01,  ...,  2.5757e-02,\n",
       "            4.0039e-02,  7.3438e-01],\n",
       "          [ 2.7734e-01, -5.1172e-01, -1.3672e-01,  ..., -3.1641e-01,\n",
       "            7.7148e-02,  1.5137e-01]],\n",
       "\n",
       "         [[-6.0120e-03, -3.4180e-03,  1.4267e-03,  ..., -1.0254e-02,\n",
       "           -4.3640e-03, -9.7046e-03],\n",
       "          [ 8.5449e-02,  5.0781e-02,  1.1230e-01,  ...,  4.6875e-01,\n",
       "            5.2490e-02, -5.3223e-02],\n",
       "          [ 2.5977e-01,  1.3672e-01,  1.5430e-01,  ..., -5.4688e-02,\n",
       "           -1.4258e-01, -2.4219e-01],\n",
       "          ...,\n",
       "          [-3.0273e-01,  2.2461e-02, -4.0039e-02,  ..., -1.3245e-02,\n",
       "           -6.1768e-02,  4.4336e-01],\n",
       "          [ 1.5625e-01,  7.2656e-01, -1.3672e-01,  ...,  2.0020e-01,\n",
       "           -6.5918e-02,  4.9414e-01],\n",
       "          [-7.3828e-01,  4.5508e-01,  5.9375e-01,  ...,  1.0010e-01,\n",
       "            9.6875e-01,  6.8750e-01]],\n",
       "\n",
       "         [[-4.0283e-02, -1.2207e-02,  2.0508e-02,  ...,  1.3611e-02,\n",
       "           -1.4038e-03, -6.3171e-03],\n",
       "          [ 3.4180e-01,  6.2500e-02,  3.6523e-01,  ...,  7.2754e-02,\n",
       "           -2.8516e-01, -1.8750e-01],\n",
       "          [ 1.2891e-01,  2.0996e-01,  1.4648e-01,  ..., -1.7090e-01,\n",
       "            3.3447e-02, -1.5430e-01],\n",
       "          ...,\n",
       "          [-2.1680e-01, -4.0820e-01,  1.3770e-01,  ..., -8.5449e-02,\n",
       "           -1.6699e-01, -1.7090e-01],\n",
       "          [-1.8359e-01, -5.6250e-01, -2.8906e-01,  ..., -2.8320e-01,\n",
       "           -2.6172e-01,  1.3379e-01],\n",
       "          [-6.7969e-01, -1.1133e-01,  4.7852e-01,  ..., -5.1953e-01,\n",
       "           -1.1621e-01, -7.2266e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 4.5898e-02,  4.6875e-02, -1.4709e-02,  ...,  1.8799e-02,\n",
       "            1.0938e-01, -1.2500e-01],\n",
       "          [ 1.8066e-01,  1.6992e-01,  5.5664e-02,  ..., -7.5000e-01,\n",
       "            5.3125e-01, -6.3672e-01],\n",
       "          [-1.7578e-01,  2.1680e-01, -2.3633e-01,  ..., -7.2656e-01,\n",
       "            3.7305e-01, -1.0469e+00],\n",
       "          ...,\n",
       "          [ 1.7969e-01,  1.0742e-01,  2.5586e-01,  ..., -3.3984e-01,\n",
       "            1.3984e+00,  1.0469e+00],\n",
       "          [-9.1797e-02, -3.2031e-01, -8.4766e-01,  ...,  1.8066e-01,\n",
       "            3.9453e-01,  1.6953e+00],\n",
       "          [ 1.9727e-01,  2.0996e-01, -4.6875e-01,  ..., -1.5156e+00,\n",
       "            1.0547e+00, -1.4258e-01]],\n",
       "\n",
       "         [[ 1.9653e-02, -1.3428e-03,  5.5847e-03,  ...,  6.3281e-01,\n",
       "           -3.2031e-01,  4.9219e-01],\n",
       "          [-5.1562e-01,  1.9141e-01, -1.7676e-01,  ...,  1.3203e+00,\n",
       "            8.2031e-02, -5.5078e-01],\n",
       "          [ 1.8750e-01,  4.0430e-01,  5.3516e-01,  ...,  1.0859e+00,\n",
       "           -1.2422e+00, -2.4805e-01],\n",
       "          ...,\n",
       "          [ 2.5391e-01,  9.3750e-02,  2.5391e-01,  ..., -3.5352e-01,\n",
       "           -2.6875e+00, -1.7344e+00],\n",
       "          [ 3.3789e-01,  2.2949e-01, -1.2256e-01,  ..., -1.4609e+00,\n",
       "           -1.7422e+00, -1.2891e+00],\n",
       "          [ 8.8379e-02, -9.7656e-02, -1.8164e-01,  ..., -3.3125e+00,\n",
       "           -1.0938e+00, -1.5312e+00]],\n",
       "\n",
       "         [[ 1.1902e-02,  4.4441e-04, -5.2795e-03,  ...,  7.3730e-02,\n",
       "            8.7109e-01, -5.5078e-01],\n",
       "          [-4.3945e-02, -1.9922e-01, -1.0791e-01,  ...,  4.6094e-01,\n",
       "           -4.6875e+00,  3.2656e+00],\n",
       "          [-7.1094e-01,  2.3242e-01, -2.1606e-02,  ...,  1.2812e+00,\n",
       "           -5.6250e+00,  1.7344e+00],\n",
       "          ...,\n",
       "          [ 2.6562e-01, -5.0000e-01,  3.9551e-02,  ..., -1.1035e-01,\n",
       "           -4.1250e+00,  1.1328e+00],\n",
       "          [ 3.5742e-01, -9.8633e-02, -7.8125e-02,  ..., -7.8125e-02,\n",
       "           -2.7188e+00, -9.5703e-01],\n",
       "          [-1.2305e-01,  3.4180e-01,  9.2773e-02,  ...,  1.5469e+00,\n",
       "           -3.1250e+00,  1.9727e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-2.1118e-02,  6.0425e-03,  7.9346e-03,  ...,  4.9805e-01,\n",
       "            1.1875e+00,  5.0391e-01],\n",
       "          [ 5.8984e-01, -3.6719e-01, -4.1016e-01,  ..., -9.6875e-01,\n",
       "           -4.1406e-01, -7.1094e-01],\n",
       "          [ 1.1816e-01, -4.1748e-02,  3.3594e-01,  ...,  2.2095e-02,\n",
       "           -2.2812e+00, -1.3516e+00],\n",
       "          ...,\n",
       "          [-7.9688e-01, -1.4941e-01,  1.8066e-02,  ...,  5.2979e-02,\n",
       "            1.3965e-01, -1.0938e+00],\n",
       "          [ 9.5215e-02, -1.9629e-01, -4.4922e-01,  ...,  2.7734e-01,\n",
       "           -1.8672e+00, -7.4609e-01],\n",
       "          [ 5.3125e-01,  1.1963e-02, -6.8359e-03,  ...,  5.0781e-01,\n",
       "           -2.5312e+00, -6.7578e-01]],\n",
       "\n",
       "         [[ 2.4292e-02, -4.5654e-02,  4.9316e-02,  ...,  2.8906e-01,\n",
       "            2.8516e-01, -2.1875e+00],\n",
       "          [-1.8359e-01,  1.1963e-01, -6.8359e-03,  ..., -4.5898e-02,\n",
       "            1.3438e+00,  2.7500e+00],\n",
       "          [ 5.8594e-01, -7.9590e-02, -3.8086e-01,  ..., -8.3203e-01,\n",
       "            1.1875e+00,  3.7188e+00],\n",
       "          ...,\n",
       "          [ 9.7656e-02,  2.2168e-01,  5.1172e-01,  ..., -7.1875e-01,\n",
       "            1.3203e+00,  2.9062e+00],\n",
       "          [-5.8594e-01,  1.9629e-01, -1.9336e-01,  ..., -5.5078e-01,\n",
       "           -8.8672e-01,  3.0156e+00],\n",
       "          [ 2.1387e-01, -1.3477e-01,  4.8828e-02,  ..., -7.7344e-01,\n",
       "           -7.6953e-01,  2.7188e+00]],\n",
       "\n",
       "         [[-3.3203e-02,  2.3071e-02, -5.4321e-03,  ...,  3.2031e-01,\n",
       "           -2.1582e-01, -8.4766e-01],\n",
       "          [ 2.4902e-01, -4.1016e-02,  3.5547e-01,  ..., -1.3125e+00,\n",
       "            6.4844e-01,  1.2578e+00],\n",
       "          [-6.3672e-01,  5.2344e-01,  5.9814e-02,  ..., -2.1582e-01,\n",
       "           -3.3203e-01,  2.0156e+00],\n",
       "          ...,\n",
       "          [-9.2773e-02,  1.7383e-01, -7.4219e-02,  ..., -1.9141e-01,\n",
       "            1.0000e+00,  8.7109e-01],\n",
       "          [ 2.3145e-01,  3.3203e-01, -1.8750e-01,  ..., -1.3359e+00,\n",
       "            1.2500e+00,  1.7969e+00],\n",
       "          [ 1.7285e-01, -1.0234e+00, -4.9023e-01,  ...,  1.3770e-01,\n",
       "            3.3281e+00,  2.8125e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-1.7212e-02, -5.1880e-03,  1.0437e-02,  ...,  1.0681e-02,\n",
       "            2.6245e-02,  3.0884e-02],\n",
       "          [-3.0469e-01,  3.5742e-01,  8.3008e-02,  ..., -7.4219e-02,\n",
       "           -1.7871e-01, -1.5820e-01],\n",
       "          [-4.5117e-01,  4.6680e-01, -5.9082e-02,  ...,  1.4062e-01,\n",
       "           -5.7422e-01, -5.8594e-01],\n",
       "          ...,\n",
       "          [-1.5625e-01,  9.7656e-02,  3.0664e-01,  ..., -1.0254e-01,\n",
       "            2.2363e-01, -3.2031e-01],\n",
       "          [-3.5156e-01, -3.3398e-01,  1.9727e-01,  ...,  6.2500e-02,\n",
       "           -3.9844e-01,  1.8750e-01],\n",
       "          [ 3.9453e-01,  3.0078e-01,  4.4531e-01,  ..., -6.9922e-01,\n",
       "           -6.9141e-01,  1.1670e-01]],\n",
       "\n",
       "         [[ 5.8899e-03,  3.0640e-02, -1.3916e-02,  ...,  1.6016e-01,\n",
       "            1.0864e-02,  2.4567e-03],\n",
       "          [-1.5625e-01, -2.4414e-01,  2.3438e-01,  ..., -5.4688e-01,\n",
       "           -1.0498e-01, -5.7031e-01],\n",
       "          [-2.0898e-01,  1.3281e-01, -1.0791e-01,  ..., -1.9629e-01,\n",
       "            1.6113e-01,  5.9570e-02],\n",
       "          ...,\n",
       "          [-2.9492e-01, -2.2559e-01,  1.9531e-01,  ..., -3.3008e-01,\n",
       "           -9.0625e-01, -1.6211e-01],\n",
       "          [-1.3770e-01, -1.3379e-01,  3.4570e-01,  ..., -2.7539e-01,\n",
       "           -6.6797e-01,  1.1377e-01],\n",
       "          [ 1.9897e-02,  5.0391e-01,  4.7656e-01,  ..., -6.1279e-02,\n",
       "           -7.6953e-01, -1.3379e-01]],\n",
       "\n",
       "         [[-1.1673e-03, -1.1780e-02,  2.0142e-03,  ..., -8.6670e-03,\n",
       "           -3.3722e-03, -2.1515e-03],\n",
       "          [-3.6719e-01,  2.9492e-01, -3.5352e-01,  ...,  3.3594e-01,\n",
       "           -4.2578e-01, -1.2031e+00],\n",
       "          [-1.2207e-01,  4.1406e-01, -1.0498e-01,  ...,  7.8125e-01,\n",
       "           -1.2422e+00, -4.8047e-01],\n",
       "          ...,\n",
       "          [-4.1406e-01,  6.0547e-01, -2.2070e-01,  ...,  2.8809e-02,\n",
       "           -1.4355e-01, -7.5000e-01],\n",
       "          [ 2.2949e-01,  2.4805e-01,  1.1621e-01,  ...,  7.1875e-01,\n",
       "           -8.0078e-01, -2.2559e-01],\n",
       "          [-6.7578e-01,  2.6758e-01, -4.0234e-01,  ...,  1.7383e-01,\n",
       "           -1.0156e+00,  3.4332e-03]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-3.3379e-04,  7.7820e-03,  3.3264e-03,  ..., -3.1433e-03,\n",
       "            1.9165e-02, -3.2349e-03],\n",
       "          [-2.3242e-01, -4.0234e-01, -2.7344e-01,  ..., -4.4336e-01,\n",
       "            2.9688e-01,  6.2891e-01],\n",
       "          [ 3.9844e-01,  2.1680e-01,  9.4531e-01,  ..., -2.5977e-01,\n",
       "            1.6602e-01,  1.7773e-01],\n",
       "          ...,\n",
       "          [ 2.7344e-01, -8.9722e-03,  2.8711e-01,  ..., -9.8438e-01,\n",
       "           -1.0234e+00, -5.4297e-01],\n",
       "          [ 1.8555e-01,  6.1719e-01,  5.0049e-02,  ..., -5.3516e-01,\n",
       "           -7.7734e-01, -9.2578e-01],\n",
       "          [ 9.0942e-03,  4.5117e-01,  4.0234e-01,  ..., -1.1719e-01,\n",
       "           -8.3984e-01, -2.1094e-01]],\n",
       "\n",
       "         [[-1.5182e-03,  2.5787e-03,  4.1504e-03,  ...,  7.5531e-04,\n",
       "            3.5645e-02, -2.8442e-02],\n",
       "          [ 6.3477e-02,  3.1641e-01, -2.1973e-02,  ...,  1.7969e-01,\n",
       "           -3.3203e-01, -1.7871e-01],\n",
       "          [-1.0254e-01, -2.2363e-01, -3.4961e-01,  ...,  2.0898e-01,\n",
       "            1.2012e-01,  1.0254e-01],\n",
       "          ...,\n",
       "          [-2.8711e-01, -4.3359e-01, -3.1641e-01,  ...,  1.4062e-01,\n",
       "            2.2266e-01,  4.0430e-01],\n",
       "          [-7.9956e-03,  3.4424e-02, -4.6094e-01,  ...,  2.3926e-01,\n",
       "            3.1250e-01, -9.0820e-02],\n",
       "          [ 6.6016e-01, -2.9883e-01, -4.9414e-01,  ..., -2.0117e-01,\n",
       "            4.5166e-03, -1.5918e-01]],\n",
       "\n",
       "         [[ 2.8442e-02, -4.3701e-02,  2.7344e-02,  ..., -4.3213e-02,\n",
       "           -2.1118e-02,  2.2583e-02],\n",
       "          [ 1.3477e-01,  5.4297e-01,  3.3789e-01,  ...,  7.2266e-02,\n",
       "           -2.8711e-01,  1.9727e-01],\n",
       "          [-9.3262e-02,  4.5898e-01,  6.4062e-01,  ...,  1.2207e-01,\n",
       "            1.0010e-02,  1.8188e-02],\n",
       "          ...,\n",
       "          [-4.2969e-01, -1.9531e-01,  1.2589e-03,  ...,  5.1953e-01,\n",
       "            2.4512e-01, -1.3379e-01],\n",
       "          [-5.7031e-01, -9.7168e-02,  1.2695e-01,  ...,  8.5938e-02,\n",
       "            3.1641e-01, -3.1836e-01],\n",
       "          [-2.9785e-02, -3.3594e-01,  1.2695e-01,  ..., -1.7480e-01,\n",
       "            4.5312e-01, -3.0078e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 2.7954e-02, -2.2583e-02, -2.5787e-03,  ..., -2.7539e-01,\n",
       "           -3.2959e-02,  3.2617e-01],\n",
       "          [-1.3281e-01, -4.3555e-01, -2.1094e-01,  ..., -2.8320e-01,\n",
       "           -3.3594e-01, -2.7148e-01],\n",
       "          [-3.6523e-01,  1.9629e-01, -3.9258e-01,  ...,  4.9219e-01,\n",
       "           -4.1992e-01, -1.1094e+00],\n",
       "          ...,\n",
       "          [ 2.8516e-01, -7.0801e-02,  2.0801e-01,  ...,  8.0859e-01,\n",
       "            1.4219e+00, -3.6328e-01],\n",
       "          [ 1.6797e-01,  5.2344e-01,  1.4258e-01,  ...,  8.5547e-01,\n",
       "           -7.1289e-02, -2.5000e-01],\n",
       "          [ 2.8125e-01, -1.3184e-01, -3.4766e-01,  ..., -2.7222e-02,\n",
       "            1.3984e+00, -1.4453e+00]],\n",
       "\n",
       "         [[ 1.4404e-02, -1.0132e-02, -5.0659e-03,  ..., -8.9844e-02,\n",
       "           -1.9897e-02, -1.0010e-01],\n",
       "          [-1.2695e-01,  1.7090e-01,  2.1875e-01,  ...,  3.3447e-02,\n",
       "           -1.0938e+00,  9.0625e-01],\n",
       "          [-3.3203e-02,  1.5625e-02,  2.2168e-01,  ..., -6.7578e-01,\n",
       "            1.2578e+00,  1.5312e+00],\n",
       "          ...,\n",
       "          [-2.4805e-01, -1.5430e-01, -2.9883e-01,  ..., -1.2812e+00,\n",
       "            3.0859e-01,  1.3203e+00],\n",
       "          [-2.9688e-01,  9.5312e-01, -4.5312e-01,  ..., -5.8984e-01,\n",
       "            1.4688e+00,  2.0156e+00],\n",
       "          [ 2.0898e-01,  1.0156e-01, -5.8594e-01,  ...,  3.9062e-01,\n",
       "            7.8906e-01,  1.9922e+00]],\n",
       "\n",
       "         [[ 9.8267e-03, -1.1047e-02, -5.5542e-03,  ..., -6.4941e-02,\n",
       "            5.7983e-03, -9.8828e-01],\n",
       "          [-4.9609e-01, -4.8096e-02, -2.3535e-01,  ..., -2.7930e-01,\n",
       "           -1.6406e+00, -1.4375e+00],\n",
       "          [ 7.8125e-03,  5.1562e-01,  1.6211e-01,  ...,  1.6797e+00,\n",
       "            3.7305e-01, -5.6250e-01],\n",
       "          ...,\n",
       "          [ 1.8066e-01,  9.5215e-02, -6.1719e-01,  ...,  4.3555e-01,\n",
       "           -6.2891e-01, -4.4727e-01],\n",
       "          [ 3.2031e-01, -1.6602e-02,  7.8125e-01,  ...,  6.7188e-01,\n",
       "            7.1484e-01,  3.0625e+00],\n",
       "          [-6.4453e-01,  1.1797e+00,  3.3789e-01,  ...,  3.0625e+00,\n",
       "           -3.9453e-01,  3.7344e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-3.1738e-02, -1.7944e-02,  1.9531e-02,  ..., -1.9336e-01,\n",
       "           -2.3633e-01, -7.9297e-01],\n",
       "          [-8.8281e-01, -6.1719e-01, -3.7109e-01,  ..., -1.0986e-01,\n",
       "            1.0391e+00,  1.3594e+00],\n",
       "          [ 3.9648e-01, -6.4453e-01, -4.0430e-01,  ...,  1.0352e-01,\n",
       "            7.7344e-01,  1.2969e+00],\n",
       "          ...,\n",
       "          [ 3.1250e-01, -5.1172e-01, -3.8281e-01,  ...,  8.0859e-01,\n",
       "            1.5527e-01,  6.8750e-01],\n",
       "          [ 3.6377e-02, -8.6719e-01, -5.0781e-01,  ...,  9.5703e-01,\n",
       "            3.7695e-01,  1.4922e+00],\n",
       "          [ 2.6953e-01, -9.9219e-01,  3.8867e-01,  ...,  4.0234e-01,\n",
       "            1.4297e+00,  1.3359e+00]],\n",
       "\n",
       "         [[-1.8677e-02,  1.0071e-02, -2.6978e-02,  ...,  2.6562e-01,\n",
       "           -8.6719e-01, -5.4297e-01],\n",
       "          [ 5.4688e-02,  2.4805e-01,  4.0625e-01,  ..., -2.2969e+00,\n",
       "           -9.3359e-01, -2.3145e-01],\n",
       "          [-1.6992e-01,  6.5625e-01, -1.7188e-01,  ..., -2.1562e+00,\n",
       "            1.5234e+00,  7.8516e-01],\n",
       "          ...,\n",
       "          [ 3.1836e-01, -4.9414e-01, -2.6758e-01,  ..., -1.6875e+00,\n",
       "           -1.9727e-01, -1.2734e+00],\n",
       "          [ 1.2188e+00,  1.2207e-02,  6.7188e-01,  ..., -1.5703e+00,\n",
       "            2.0469e+00,  1.0781e+00],\n",
       "          [-2.7539e-01, -1.6016e-01, -1.2598e-01,  ..., -2.2188e+00,\n",
       "            4.5508e-01,  2.6758e-01]],\n",
       "\n",
       "         [[-1.0742e-02, -2.1240e-02, -3.3188e-04,  ..., -1.2329e-02,\n",
       "            1.7188e-01, -3.3398e-01],\n",
       "          [-4.3945e-02, -3.8477e-01,  6.7578e-01,  ..., -1.6875e+00,\n",
       "           -9.2578e-01, -2.5625e+00],\n",
       "          [-2.3438e-02, -2.3633e-01,  4.2773e-01,  ..., -7.3438e-01,\n",
       "           -1.4688e+00, -4.0312e+00],\n",
       "          ...,\n",
       "          [ 4.1406e-01,  6.5918e-03,  3.6133e-01,  ...,  2.8906e-01,\n",
       "           -9.5312e-01, -3.2969e+00],\n",
       "          [ 6.2500e-01, -1.6211e-01,  3.3447e-02,  ..., -5.6250e-01,\n",
       "           -7.7734e-01, -8.7109e-01],\n",
       "          [-4.6484e-01, -1.4062e-01,  5.0781e-01,  ...,  1.5820e-01,\n",
       "           -5.1562e-01, -6.0938e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-2.3071e-02,  4.9072e-02,  5.0293e-02,  ...,  7.1411e-03,\n",
       "           -4.2725e-03, -1.6968e-02],\n",
       "          [-1.5381e-02, -3.4570e-01, -3.1055e-01,  ...,  4.2480e-02,\n",
       "           -4.4141e-01,  2.0020e-01],\n",
       "          [-5.5469e-01, -3.0859e-01,  1.6895e-01,  ..., -9.6094e-01,\n",
       "           -3.2031e-01,  1.7969e-01],\n",
       "          ...,\n",
       "          [-1.6504e-01, -4.0039e-01,  2.7539e-01,  ..., -3.7500e-01,\n",
       "           -3.4668e-02,  3.5889e-02],\n",
       "          [-3.6328e-01, -5.0293e-02,  3.6377e-02,  ..., -8.7891e-02,\n",
       "            2.3828e-01,  2.2852e-01],\n",
       "          [ 6.2500e-02,  2.1240e-02,  4.3945e-02,  ..., -1.3770e-01,\n",
       "           -1.1426e-01, -2.4121e-01]],\n",
       "\n",
       "         [[-1.7395e-03, -5.9204e-03, -1.4648e-02,  ...,  2.9541e-02,\n",
       "           -7.6675e-04, -3.5477e-04],\n",
       "          [ 1.4551e-01,  2.2461e-02, -4.9805e-02,  ..., -6.1328e-01,\n",
       "            1.8750e-01, -2.3340e-01],\n",
       "          [-3.9551e-02,  2.6172e-01,  2.8516e-01,  ..., -3.6523e-01,\n",
       "            7.0312e-01,  2.9785e-02],\n",
       "          ...,\n",
       "          [ 9.6191e-02, -1.0010e-01,  8.8867e-02,  ...,  2.4658e-02,\n",
       "            1.3379e-01, -3.8867e-01],\n",
       "          [ 4.0820e-01, -1.0132e-02, -2.1289e-01,  ..., -3.5547e-01,\n",
       "            1.2891e-01,  1.1182e-01],\n",
       "          [ 2.0898e-01, -5.5469e-01, -8.3008e-02,  ..., -9.8438e-01,\n",
       "            6.2109e-01, -1.1475e-01]],\n",
       "\n",
       "         [[-2.6611e-02,  5.6152e-03,  1.2512e-02,  ..., -1.5625e-02,\n",
       "            1.7090e-02,  1.5625e-02],\n",
       "          [ 1.1719e-01, -3.2422e-01, -1.6016e-01,  ...,  1.7871e-01,\n",
       "           -2.3340e-01,  2.6562e-01],\n",
       "          [ 2.2559e-01, -3.7500e-01, -2.0215e-01,  ..., -6.5918e-02,\n",
       "           -3.3984e-01,  4.7070e-01],\n",
       "          ...,\n",
       "          [ 2.4121e-01, -1.6504e-01,  4.1992e-01,  ..., -4.1797e-01,\n",
       "           -3.9258e-01, -1.1377e-01],\n",
       "          [ 3.1128e-02,  6.6406e-02,  3.1836e-01,  ..., -3.1836e-01,\n",
       "           -2.1875e-01, -4.3164e-01],\n",
       "          [-2.2461e-01,  6.6797e-01, -4.6997e-03,  ..., -2.0703e-01,\n",
       "           -6.9336e-02,  2.8320e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 3.6377e-02,  2.3438e-02,  1.3428e-02,  ...,  7.6294e-03,\n",
       "            2.2461e-02, -5.2002e-02],\n",
       "          [-8.8379e-02,  7.3047e-01,  1.8066e-01,  ..., -1.1865e-01,\n",
       "            9.9121e-02,  2.3438e-01],\n",
       "          [-5.2344e-01, -2.4316e-01, -4.2578e-01,  ..., -7.2754e-02,\n",
       "            2.2754e-01,  6.6797e-01],\n",
       "          ...,\n",
       "          [ 1.5039e-01,  5.9326e-02, -4.2969e-01,  ...,  2.5391e-01,\n",
       "           -1.5918e-01,  4.3359e-01],\n",
       "          [ 4.2725e-02,  1.6211e-01, -3.5547e-01,  ..., -1.5234e-01,\n",
       "           -3.4424e-02,  1.4941e-01],\n",
       "          [ 2.9688e-01, -4.9316e-02, -4.6484e-01,  ..., -1.9336e-01,\n",
       "           -3.9648e-01, -5.1562e-01]],\n",
       "\n",
       "         [[ 1.9409e-02, -2.1973e-02, -9.1553e-03,  ...,  3.6011e-03,\n",
       "            2.2095e-02,  4.5166e-03],\n",
       "          [ 3.2812e-01, -1.4648e-01,  5.3906e-01,  ..., -3.7305e-01,\n",
       "            4.0234e-01,  5.0781e-01],\n",
       "          [ 4.0820e-01, -4.6680e-01,  3.4180e-01,  ...,  1.7383e-01,\n",
       "            4.1406e-01,  2.9492e-01],\n",
       "          ...,\n",
       "          [ 4.5508e-01, -5.0781e-01,  4.7852e-01,  ...,  3.5352e-01,\n",
       "            2.7539e-01, -1.1035e-01],\n",
       "          [ 2.1191e-01, -5.4297e-01,  4.2773e-01,  ...,  6.6797e-01,\n",
       "            1.2305e-01,  7.5684e-02],\n",
       "          [ 2.6611e-02, -2.1680e-01,  1.7480e-01,  ...,  1.5527e-01,\n",
       "           -2.2070e-01,  9.2285e-02]],\n",
       "\n",
       "         [[-4.3945e-02, -1.4465e-02, -1.4954e-02,  ..., -1.3611e-02,\n",
       "            1.4954e-02,  6.3782e-03],\n",
       "          [-8.3008e-02, -1.7090e-01,  2.4707e-01,  ...,  2.8125e-01,\n",
       "           -1.7285e-01,  1.9824e-01],\n",
       "          [ 3.5352e-01, -8.2422e-01,  6.4453e-01,  ...,  5.8594e-02,\n",
       "           -2.8076e-02,  1.0681e-02],\n",
       "          ...,\n",
       "          [ 2.8125e-01, -1.8457e-01,  6.0156e-01,  ...,  6.6406e-01,\n",
       "           -3.5547e-01, -3.9062e-01],\n",
       "          [-7.4219e-02, -1.1353e-02,  9.2188e-01,  ...,  3.7891e-01,\n",
       "           -6.4062e-01, -8.4766e-01],\n",
       "          [-3.5938e-01, -8.6914e-02,  6.5234e-01,  ...,  5.1562e-01,\n",
       "            6.1328e-01, -3.0469e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-3.8910e-03,  1.9409e-02,  1.7456e-02,  ..., -2.2852e-01,\n",
       "            1.4746e-01, -1.0840e-01],\n",
       "          [ 8.9844e-01,  3.2227e-01,  5.0000e-01,  ..., -8.9453e-01,\n",
       "           -5.9375e-01,  4.4141e-01],\n",
       "          [-1.7578e-02, -1.8750e-01,  5.1270e-02,  ..., -3.8086e-01,\n",
       "           -1.7090e-01,  6.7188e-01],\n",
       "          ...,\n",
       "          [ 4.3359e-01,  5.3906e-01, -4.1016e-01,  ..., -4.8438e-01,\n",
       "            1.1641e+00, -1.3281e-01],\n",
       "          [ 1.1016e+00, -1.6602e-02,  5.3516e-01,  ...,  2.5977e-01,\n",
       "           -6.1279e-02,  1.2578e+00],\n",
       "          [ 7.8125e-01, -9.1406e-01, -1.5918e-01,  ..., -3.0469e-01,\n",
       "           -5.6885e-02,  6.1719e-01]],\n",
       "\n",
       "         [[-2.0264e-02,  5.3406e-03,  6.2256e-03,  ..., -2.2031e+00,\n",
       "            2.5391e-01,  4.8242e-01],\n",
       "          [-2.3340e-01,  4.7461e-01, -3.2031e-01,  ...,  2.1719e+00,\n",
       "           -2.8125e-01, -2.5781e+00],\n",
       "          [ 3.2227e-02,  4.2383e-01,  4.5312e-01,  ...,  3.7969e+00,\n",
       "           -1.4375e+00, -3.4219e+00],\n",
       "          ...,\n",
       "          [-4.8828e-04, -2.6978e-02,  4.9609e-01,  ...,  2.8594e+00,\n",
       "           -1.9141e+00, -2.5625e+00],\n",
       "          [-4.5117e-01, -2.8125e-01,  1.2695e-02,  ...,  4.8750e+00,\n",
       "           -2.5781e+00, -3.1406e+00],\n",
       "          [-2.6953e-01, -2.6953e-01, -3.4570e-01,  ...,  3.2969e+00,\n",
       "           -2.3281e+00, -2.4219e+00]],\n",
       "\n",
       "         [[-5.2795e-03, -3.9673e-03,  3.2959e-02,  ...,  6.6406e-02,\n",
       "           -5.2002e-02, -1.9434e-01],\n",
       "          [-4.7852e-02, -7.1875e-01,  8.7891e-02,  ...,  1.3281e-01,\n",
       "           -2.4531e+00,  9.4922e-01],\n",
       "          [ 2.3730e-01, -1.1719e-01, -2.9053e-02,  ..., -2.0898e-01,\n",
       "           -2.2969e+00,  9.5367e-04],\n",
       "          ...,\n",
       "          [ 3.1641e-01,  4.2773e-01, -6.9141e-01,  ...,  1.0391e+00,\n",
       "            1.3359e+00, -1.1406e+00],\n",
       "          [ 6.6797e-01,  3.7500e-01,  6.7188e-01,  ..., -4.4922e-01,\n",
       "            7.1094e-01, -1.6719e+00],\n",
       "          [ 2.1387e-01, -5.5859e-01,  2.9785e-02,  ..., -1.1875e+00,\n",
       "           -1.6504e-01, -6.5234e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.5503e-02, -8.9722e-03, -1.7578e-02,  ..., -1.8281e+00,\n",
       "            2.2031e+00, -6.2109e-01],\n",
       "          [-2.7344e-02, -1.8652e-01,  4.4922e-02,  ...,  1.7344e+00,\n",
       "           -9.6094e-01,  2.4062e+00],\n",
       "          [ 3.1494e-02,  5.7617e-02,  3.0029e-02,  ...,  3.8438e+00,\n",
       "           -3.8125e+00,  2.4688e+00],\n",
       "          ...,\n",
       "          [-4.4141e-01,  3.2227e-02, -6.1035e-02,  ...,  3.3906e+00,\n",
       "           -1.8047e+00,  1.3828e+00],\n",
       "          [ 2.1387e-01,  4.0625e-01,  1.0254e-02,  ...,  5.0625e+00,\n",
       "           -3.6250e+00,  1.0312e+00],\n",
       "          [-2.2656e-01,  2.2852e-01, -4.6484e-01,  ...,  5.0312e+00,\n",
       "           -1.1172e+00, -6.4062e-01]],\n",
       "\n",
       "         [[-2.4658e-02,  1.1475e-02, -6.4697e-03,  ..., -1.4062e-01,\n",
       "            1.7285e-01, -7.4219e-01],\n",
       "          [ 1.0645e-01, -4.4727e-01, -2.1680e-01,  ..., -3.5156e-02,\n",
       "           -1.3672e-01,  2.0469e+00],\n",
       "          [-2.5977e-01, -1.0352e-01, -5.3467e-02,  ...,  2.5000e-01,\n",
       "           -6.5625e-01,  2.7344e+00],\n",
       "          ...,\n",
       "          [-2.9297e-03,  4.6875e-01, -2.3535e-01,  ...,  6.9141e-01,\n",
       "            1.0156e+00,  1.3203e+00],\n",
       "          [-2.2266e-01,  6.5234e-01, -3.4180e-02,  ...,  6.6406e-01,\n",
       "           -6.9824e-02,  1.2656e+00],\n",
       "          [-5.3906e-01, -3.5547e-01,  9.4727e-02,  ..., -2.2070e-01,\n",
       "           -1.3438e+00, -1.8799e-02]],\n",
       "\n",
       "         [[ 1.0925e-02,  1.1841e-02,  7.8125e-03,  ..., -2.7588e-02,\n",
       "           -1.4844e-01,  5.4932e-02],\n",
       "          [-2.6367e-01, -6.9336e-02, -4.2969e-01,  ...,  4.5117e-01,\n",
       "            1.3438e+00, -1.6016e+00],\n",
       "          [ 6.8750e-01,  1.2793e-01,  6.7188e-01,  ...,  2.1118e-02,\n",
       "            9.2969e-01,  1.7969e-01],\n",
       "          ...,\n",
       "          [-1.2891e+00,  1.2344e+00,  7.5781e-01,  ...,  4.0430e-01,\n",
       "            5.8203e-01,  7.1094e-01],\n",
       "          [-8.9844e-01,  8.5938e-02,  2.0801e-01,  ..., -8.3594e-01,\n",
       "            5.3906e-01,  1.0156e+00],\n",
       "          [ 1.9727e-01, -4.2578e-01,  4.0234e-01,  ..., -1.9062e+00,\n",
       "            3.8867e-01, -2.0996e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 4.4434e-02,  9.8877e-03,  1.0803e-02,  ...,  9.8419e-04,\n",
       "           -2.1515e-03, -6.1523e-02],\n",
       "          [ 2.6367e-01, -1.7480e-01,  1.2158e-01,  ..., -4.1406e-01,\n",
       "            2.5586e-01,  2.6562e-01],\n",
       "          [ 3.5547e-01,  5.7031e-01, -3.0151e-02,  ...,  4.8828e-02,\n",
       "           -4.9072e-02,  5.6763e-03],\n",
       "          ...,\n",
       "          [ 6.5234e-01,  1.1719e-01, -1.8066e-01,  ..., -3.0396e-02,\n",
       "           -3.8574e-02,  6.6406e-01],\n",
       "          [ 3.8477e-01, -6.8359e-02,  1.5820e-01,  ..., -2.8516e-01,\n",
       "           -3.2227e-01,  6.4062e-01],\n",
       "          [ 3.2031e-01,  9.3994e-03,  7.8125e-03,  ...,  2.5391e-01,\n",
       "            7.8906e-01, -2.5000e-01]],\n",
       "\n",
       "         [[ 2.7161e-03, -2.6245e-03,  7.4768e-03,  ..., -6.3705e-04,\n",
       "            3.2806e-03,  3.7842e-03],\n",
       "          [ 1.4954e-02, -1.3086e-01, -6.4453e-02,  ..., -4.3945e-02,\n",
       "           -1.0840e-01,  5.1953e-01],\n",
       "          [ 1.9336e-01, -6.9531e-01, -6.7969e-01,  ...,  9.1309e-02,\n",
       "           -4.4922e-01,  3.2422e-01],\n",
       "          ...,\n",
       "          [-2.7222e-02,  9.9609e-01,  1.1084e-01,  ..., -3.6523e-01,\n",
       "           -5.6152e-02, -2.9688e-01],\n",
       "          [ 8.3496e-02,  4.3359e-01,  3.1055e-01,  ..., -1.4746e-01,\n",
       "            4.2969e-01, -2.1094e-01],\n",
       "          [ 1.3281e-01,  8.4766e-01,  4.0625e-01,  ..., -3.1445e-01,\n",
       "           -8.1250e-01, -6.3281e-01]],\n",
       "\n",
       "         [[ 7.8735e-03, -4.8218e-03, -2.4719e-03,  ...,  1.5503e-02,\n",
       "            5.1575e-03, -1.1215e-03],\n",
       "          [-8.0566e-02, -1.2793e-01,  2.0801e-01,  ...,  4.2578e-01,\n",
       "            5.2002e-02,  1.0156e-01],\n",
       "          [-1.2451e-01,  2.6953e-01,  4.3945e-01,  ...,  6.2891e-01,\n",
       "           -2.1729e-02,  1.7773e-01],\n",
       "          ...,\n",
       "          [ 2.2949e-01, -3.3203e-02, -5.1562e-01,  ...,  9.6680e-02,\n",
       "            2.6758e-01,  3.2422e-01],\n",
       "          [-4.6387e-02, -1.6113e-01, -3.5352e-01,  ..., -7.5195e-02,\n",
       "           -1.6113e-01, -2.8516e-01],\n",
       "          [-2.0312e-01,  2.6172e-01, -1.3477e-01,  ...,  6.6016e-01,\n",
       "            2.6172e-01,  6.9141e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.7212e-02, -2.4414e-02, -8.5449e-03,  ...,  7.2937e-03,\n",
       "           -3.1250e-02,  2.9907e-02],\n",
       "          [ 3.7598e-02,  2.1582e-01,  3.6719e-01,  ..., -2.1875e-01,\n",
       "            3.5938e-01,  1.1133e-01],\n",
       "          [ 1.5527e-01,  1.1621e-01, -3.1055e-01,  ..., -2.7148e-01,\n",
       "           -2.9907e-03,  2.6758e-01],\n",
       "          ...,\n",
       "          [ 2.0117e-01,  3.4912e-02,  1.3184e-01,  ..., -2.1851e-02,\n",
       "           -1.5430e-01,  9.0625e-01],\n",
       "          [ 4.3945e-01, -1.5039e-01,  1.3184e-01,  ..., -1.5625e-01,\n",
       "           -4.1602e-01,  6.6016e-01],\n",
       "          [ 2.5781e-01,  6.5918e-02,  2.1387e-01,  ..., -3.4180e-01,\n",
       "           -5.6250e-01,  9.4922e-01]],\n",
       "\n",
       "         [[-3.0823e-03, -2.5513e-02,  9.4604e-03,  ..., -2.1362e-02,\n",
       "           -5.3711e-03, -2.2461e-02],\n",
       "          [-7.6172e-01,  1.3184e-01,  3.4180e-01,  ...,  2.1191e-01,\n",
       "           -4.1211e-01,  2.8516e-01],\n",
       "          [-2.1680e-01,  3.9453e-01,  6.5234e-01,  ..., -3.8281e-01,\n",
       "            5.5469e-01, -1.1963e-01],\n",
       "          ...,\n",
       "          [ 6.9824e-02, -3.0078e-01, -1.4355e-01,  ..., -2.1680e-01,\n",
       "            1.4746e-01, -1.2109e-01],\n",
       "          [-1.5137e-01, -1.5625e-01,  1.5039e-01,  ..., -4.2578e-01,\n",
       "            4.8218e-03,  6.4453e-02],\n",
       "          [-1.1094e+00,  6.8359e-01,  3.7109e-01,  ...,  7.6172e-02,\n",
       "           -2.8711e-01, -2.2339e-02]],\n",
       "\n",
       "         [[ 3.4668e-02,  1.3809e-03, -9.1553e-03,  ...,  5.7068e-03,\n",
       "            8.9111e-03,  1.6113e-02],\n",
       "          [ 5.7068e-03,  1.4746e-01,  1.6797e-01,  ..., -3.9844e-01,\n",
       "           -2.1973e-01, -1.8311e-02],\n",
       "          [ 3.8867e-01,  4.6094e-01, -6.8750e-01,  ..., -1.2793e-01,\n",
       "           -2.3926e-01,  1.0156e-01],\n",
       "          ...,\n",
       "          [-6.3965e-02, -5.4932e-02,  1.9043e-01,  ...,  4.4434e-02,\n",
       "           -3.3008e-01, -1.8677e-02],\n",
       "          [ 5.3516e-01,  6.9531e-01, -3.4375e-01,  ..., -3.5938e-01,\n",
       "           -6.8848e-02, -2.7539e-01],\n",
       "          [-4.9219e-01,  7.9102e-02, -2.3926e-01,  ..., -1.7383e-01,\n",
       "            1.7212e-02, -3.6523e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-1.4709e-02, -5.3711e-03, -2.3193e-02,  ..., -1.2451e-01,\n",
       "           -6.5918e-02, -9.0332e-02],\n",
       "          [-2.6758e-01, -1.2598e-01, -1.3184e-02,  ...,  3.8672e-01,\n",
       "           -1.3203e+00,  3.3750e+00],\n",
       "          [ 2.0996e-02,  3.3984e-01,  2.0801e-01,  ...,  3.1055e-01,\n",
       "           -7.5781e-01,  2.4062e+00],\n",
       "          ...,\n",
       "          [ 1.3086e-01, -4.1992e-01,  3.8477e-01,  ..., -8.1250e-01,\n",
       "            9.2188e-01,  2.6719e+00],\n",
       "          [ 1.2256e-01, -8.8281e-01, -2.0312e-01,  ...,  5.6250e-01,\n",
       "            2.6953e-01,  2.6719e+00],\n",
       "          [ 1.7578e-01, -7.8125e-01, -2.0605e-01,  ...,  7.7344e-01,\n",
       "            8.7500e-01,  1.0469e+00]],\n",
       "\n",
       "         [[ 1.2329e-02, -8.1177e-03, -4.3640e-03,  ..., -7.2754e-02,\n",
       "           -4.9561e-02, -1.3965e-01],\n",
       "          [-1.3867e-01,  3.7891e-01,  2.0996e-01,  ..., -6.4062e-01,\n",
       "            2.4062e+00, -2.3281e+00],\n",
       "          [-7.1094e-01,  1.3906e+00, -5.3906e-01,  ..., -2.8320e-01,\n",
       "            2.2188e+00, -1.6875e+00],\n",
       "          ...,\n",
       "          [-3.3203e-02, -1.9434e-01, -3.7109e-02,  ...,  5.4297e-01,\n",
       "            1.5391e+00, -1.5078e+00],\n",
       "          [ 2.0898e-01,  4.5312e-01, -8.7109e-01,  ...,  2.6953e-01,\n",
       "           -4.5654e-02, -1.2422e+00],\n",
       "          [-2.5195e-01, -9.2773e-02, -1.6406e-01,  ...,  2.5391e-01,\n",
       "            1.3047e+00, -1.3672e+00]],\n",
       "\n",
       "         [[ 2.9663e-02,  5.6641e-02, -1.5488e-03,  ...,  3.9453e-01,\n",
       "           -5.6641e-01, -5.8350e-02],\n",
       "          [ 3.7891e-01,  2.1484e-01, -7.1289e-02,  ..., -1.8750e-01,\n",
       "            2.7930e-01,  2.9688e-01],\n",
       "          [-3.2422e-01,  2.3438e-01, -1.4375e+00,  ...,  4.4336e-01,\n",
       "            5.4688e-02, -2.0508e-01],\n",
       "          ...,\n",
       "          [ 4.4531e-01, -2.6172e-01, -1.8945e-01,  ..., -1.2891e-01,\n",
       "            1.7266e+00, -1.2109e+00],\n",
       "          [ 4.3750e-01,  1.1865e-01, -9.7656e-03,  ...,  7.5781e-01,\n",
       "            2.4688e+00, -8.2422e-01],\n",
       "          [ 2.3633e-01, -8.0566e-02, -7.2266e-02,  ..., -9.4238e-02,\n",
       "            1.7188e-01,  3.5156e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 4.0527e-02, -1.9287e-02, -1.7700e-02,  ...,  4.7852e-02,\n",
       "           -4.3164e-01,  1.9141e-01],\n",
       "          [-4.9023e-01, -6.3672e-01,  6.1719e-01,  ...,  1.3125e+00,\n",
       "            5.1562e-01, -8.2422e-01],\n",
       "          [ 1.9531e-03,  2.8906e-01, -3.4961e-01,  ...,  1.3750e+00,\n",
       "            4.6875e-01, -7.5781e-01],\n",
       "          ...,\n",
       "          [ 5.8203e-01, -1.1816e-01,  7.1875e-01,  ...,  7.7344e-01,\n",
       "            2.6406e+00, -1.0312e+00],\n",
       "          [ 3.9307e-02,  9.7656e-04,  3.7500e-01,  ...,  8.9844e-01,\n",
       "            1.7109e+00, -1.0596e-01],\n",
       "          [ 1.0742e-01, -3.0078e-01, -2.9785e-02,  ..., -1.9727e-01,\n",
       "            2.8594e+00,  8.3203e-01]],\n",
       "\n",
       "         [[-2.5787e-03,  2.0142e-02,  2.4048e-02,  ...,  7.2266e-01,\n",
       "           -2.2363e-01, -2.5000e+00],\n",
       "          [ 7.5391e-01, -3.2031e-01, -4.3213e-02,  ..., -2.0215e-01,\n",
       "            7.1094e-01,  7.3438e+00],\n",
       "          [ 6.7188e-01, -9.3750e-02,  1.6309e-01,  ..., -6.1328e-01,\n",
       "            2.9297e-01,  9.8125e+00],\n",
       "          ...,\n",
       "          [-3.0273e-01,  1.2891e-01,  7.7344e-01,  ..., -5.6250e-01,\n",
       "            1.4453e+00,  1.0375e+01],\n",
       "          [-2.7734e-01, -4.2773e-01, -2.2070e-01,  ..., -5.6641e-02,\n",
       "           -3.1055e-01,  1.0312e+01],\n",
       "          [ 1.9336e-01,  3.0078e-01,  4.5410e-02,  ..., -2.3594e+00,\n",
       "           -2.6953e-01,  7.8750e+00]],\n",
       "\n",
       "         [[-3.3691e-02,  2.7588e-02, -1.8188e-02,  ..., -2.9297e-01,\n",
       "           -1.6797e-01,  2.5586e-01],\n",
       "          [ 1.0352e-01,  3.2617e-01,  1.1621e-01,  ...,  2.1875e+00,\n",
       "            1.6953e+00,  1.1797e+00],\n",
       "          [-3.4375e-01,  4.6484e-01,  2.5391e-01,  ...,  3.2812e+00,\n",
       "            2.0312e+00,  1.7266e+00],\n",
       "          ...,\n",
       "          [-3.8281e-01, -2.4512e-01,  1.2158e-01,  ...,  2.1406e+00,\n",
       "            1.3047e+00,  3.2812e+00],\n",
       "          [-2.2168e-01, -3.3008e-01,  5.7812e-01,  ...,  1.1875e+00,\n",
       "            6.3281e-01,  2.8125e+00],\n",
       "          [-8.9844e-02, -1.7773e-01, -2.7734e-01,  ...,  2.8594e+00,\n",
       "           -3.7109e-01,  5.6250e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 4.9472e-06,  7.6599e-03, -1.0254e-02,  ...,  1.5198e-02,\n",
       "           -7.7515e-03, -2.0142e-02],\n",
       "          [-1.2793e-01,  2.0508e-01,  1.9336e-01,  ..., -5.1953e-01,\n",
       "           -3.3984e-01,  5.0391e-01],\n",
       "          [ 3.1006e-02, -3.0078e-01,  3.0078e-01,  ...,  5.6152e-02,\n",
       "           -3.5156e-01,  4.3359e-01],\n",
       "          ...,\n",
       "          [ 6.3281e-01, -7.6172e-02,  9.9219e-01,  ..., -4.2383e-01,\n",
       "            2.9297e-01, -6.0303e-02],\n",
       "          [ 2.4316e-01,  1.5137e-02,  5.3906e-01,  ..., -5.1172e-01,\n",
       "            2.1387e-01, -5.9814e-02],\n",
       "          [ 2.5195e-01, -3.6328e-01,  1.5723e-01,  ..., -7.8516e-01,\n",
       "            1.9727e-01, -2.4609e-01]],\n",
       "\n",
       "         [[-6.9275e-03, -2.9175e-02, -3.6133e-02,  ...,  5.0537e-02,\n",
       "           -5.7373e-03, -1.0193e-02],\n",
       "          [-4.0039e-01,  8.8379e-02, -5.3906e-01,  ..., -5.5469e-01,\n",
       "            5.1514e-02,  4.8047e-01],\n",
       "          [ 1.8066e-01,  5.3125e-01, -1.4746e-01,  ..., -9.8145e-02,\n",
       "            7.7344e-01,  6.6016e-01],\n",
       "          ...,\n",
       "          [ 8.6719e-01,  2.7710e-02,  7.3828e-01,  ...,  6.2256e-02,\n",
       "           -1.6113e-01,  2.7539e-01],\n",
       "          [ 1.6602e-02, -3.5858e-03,  2.9102e-01,  ..., -5.7617e-02,\n",
       "           -2.1289e-01,  2.8516e-01],\n",
       "          [-2.2461e-01, -4.2383e-01,  1.4609e+00,  ...,  4.7266e-01,\n",
       "            8.0859e-01,  3.4570e-01]],\n",
       "\n",
       "         [[-5.8105e-02, -5.0659e-03, -1.3733e-02,  ...,  3.2959e-02,\n",
       "            2.9541e-02,  1.0681e-02],\n",
       "          [-8.7402e-02,  3.9062e-01,  2.8711e-01,  ..., -4.8242e-01,\n",
       "            8.3984e-02, -6.4392e-03],\n",
       "          [ 4.4922e-01,  5.1172e-01,  1.7578e-01,  ..., -5.9766e-01,\n",
       "           -3.1250e-01, -1.1523e-01],\n",
       "          ...,\n",
       "          [-1.8848e-01, -5.7812e-01, -2.4512e-01,  ..., -2.5977e-01,\n",
       "           -2.9883e-01, -8.5938e-01],\n",
       "          [-2.1582e-01, -4.4141e-01,  8.5449e-03,  ..., -1.0645e-01,\n",
       "           -1.1816e-01, -9.3750e-02],\n",
       "          [ 5.8203e-01, -2.0020e-01,  1.1914e-01,  ..., -2.7100e-02,\n",
       "           -3.2422e-01, -8.2422e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 6.1951e-03,  7.2937e-03,  6.7383e-02,  ...,  1.6235e-02,\n",
       "           -1.7334e-02, -2.3071e-02],\n",
       "          [-1.0449e-01,  1.7578e-01, -5.6641e-01,  ..., -3.2031e-01,\n",
       "            2.2461e-01,  4.4141e-01],\n",
       "          [-6.7871e-02, -1.7334e-02, -2.4805e-01,  ..., -4.9805e-01,\n",
       "           -7.6172e-01,  5.3516e-01],\n",
       "          ...,\n",
       "          [ 3.0469e-01, -5.0000e-01,  4.0820e-01,  ..., -5.9375e-01,\n",
       "           -3.0078e-01,  8.6328e-01],\n",
       "          [ 1.4160e-01, -4.2969e-01,  1.1914e-01,  ..., -4.4141e-01,\n",
       "            9.2773e-02,  1.2598e-01],\n",
       "          [ 4.2578e-01, -1.1133e-01,  5.0000e-01,  ..., -6.1768e-02,\n",
       "           -6.5918e-02,  1.7969e-01]],\n",
       "\n",
       "         [[-1.6724e-02, -1.4648e-02, -7.6904e-03,  ..., -1.0071e-02,\n",
       "            5.9204e-03, -6.3477e-03],\n",
       "          [ 4.5312e-01,  1.5723e-01,  2.9688e-01,  ..., -1.0547e+00,\n",
       "           -2.5977e-01,  2.9297e-01],\n",
       "          [ 9.9219e-01,  1.0859e+00,  3.7354e-02,  ..., -9.1406e-01,\n",
       "           -2.8442e-02,  3.7500e-01],\n",
       "          ...,\n",
       "          [-8.0078e-02,  8.2422e-01, -2.9297e-01,  ..., -3.0078e-01,\n",
       "           -2.2266e-01, -6.5430e-02],\n",
       "          [ 2.3047e-01,  5.3125e-01,  3.4424e-02,  ..., -1.8457e-01,\n",
       "            3.7305e-01, -2.5781e-01],\n",
       "          [ 2.5586e-01,  8.6328e-01,  4.0771e-02,  ..., -1.9141e+00,\n",
       "            8.3984e-01, -4.6289e-01]],\n",
       "\n",
       "         [[ 1.1963e-02, -4.6997e-03,  4.4556e-03,  ...,  1.3550e-02,\n",
       "            3.0762e-02,  6.9275e-03],\n",
       "          [-4.3945e-01, -6.0156e-01, -4.3945e-01,  ..., -2.0996e-01,\n",
       "           -2.9102e-01, -6.8848e-02],\n",
       "          [-5.7422e-01, -7.4219e-01, -4.8633e-01,  ..., -3.7305e-01,\n",
       "           -3.2031e-01, -4.3359e-01],\n",
       "          ...,\n",
       "          [ 1.8262e-01, -2.3242e-01,  3.2812e-01,  ...,  5.8594e-01,\n",
       "           -8.3203e-01, -3.2812e-01],\n",
       "          [-2.4512e-01, -6.2109e-01,  4.4922e-01,  ...,  6.2500e-02,\n",
       "           -4.9219e-01, -7.9102e-02],\n",
       "          [-1.8262e-01, -4.1406e-01, -4.9805e-02,  ...,  1.8750e-01,\n",
       "           -1.2344e+00,  8.7402e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 3.5400e-02, -1.6357e-02,  1.9287e-02,  ..., -3.8672e-01,\n",
       "            2.6562e-01, -1.7480e-01],\n",
       "          [ 4.1504e-02,  2.0605e-01,  4.1797e-01,  ..., -1.3281e+00,\n",
       "           -5.4688e-01,  3.5742e-01],\n",
       "          [-3.3008e-01, -1.1406e+00,  2.3340e-01,  ..., -1.1172e+00,\n",
       "           -1.7344e+00,  1.9922e-01],\n",
       "          ...,\n",
       "          [-5.1172e-01,  6.6016e-01, -2.6562e-01,  ..., -2.5391e-01,\n",
       "           -1.4922e+00,  8.4766e-01],\n",
       "          [-2.5391e-01,  2.9883e-01,  1.7676e-01,  ...,  1.0625e+00,\n",
       "           -1.0391e+00,  2.4219e+00],\n",
       "          [ 2.8320e-02,  6.8359e-02,  1.3477e-01,  ...,  6.0547e-01,\n",
       "           -1.1016e+00, -5.6250e-01]],\n",
       "\n",
       "         [[-3.5400e-02,  4.6082e-03, -8.6060e-03,  ..., -5.8203e-01,\n",
       "           -7.4219e-01, -2.3730e-01],\n",
       "          [-1.9531e-03,  4.5166e-02,  2.3535e-01,  ..., -3.4375e+00,\n",
       "            9.6875e-01,  1.1172e+00],\n",
       "          [-7.8125e-01, -2.7344e-01,  7.2754e-02,  ..., -2.2031e+00,\n",
       "            1.8516e+00, -3.6328e-01],\n",
       "          ...,\n",
       "          [ 5.5859e-01, -1.6406e-01, -7.2266e-02,  ...,  2.9531e+00,\n",
       "            1.7109e+00,  8.1250e-01],\n",
       "          [ 4.1406e-01,  2.0117e-01,  2.0386e-02,  ...,  2.2188e+00,\n",
       "            3.7500e+00,  1.3125e+00],\n",
       "          [-4.1016e-02,  1.6479e-02, -4.9072e-02,  ...,  2.5781e+00,\n",
       "            3.1250e+00,  1.0781e+00]],\n",
       "\n",
       "         [[ 2.3193e-02, -2.2736e-03,  1.2146e-02,  ..., -1.2109e-01,\n",
       "            7.9102e-02,  3.5938e-01],\n",
       "          [-3.4375e-01,  3.5547e-01,  1.9336e-01,  ..., -1.5000e+00,\n",
       "            2.4219e+00, -1.3516e+00],\n",
       "          [-1.1914e-01,  6.8750e-01,  2.2461e-01,  ..., -1.9375e+00,\n",
       "            3.5469e+00, -1.2988e-01],\n",
       "          ...,\n",
       "          [ 4.9561e-02,  1.2354e-01, -2.3340e-01,  ...,  3.8750e+00,\n",
       "            2.4219e+00,  2.1406e+00],\n",
       "          [-3.3789e-01, -8.4375e-01, -2.8906e-01,  ...,  2.1562e+00,\n",
       "            2.5000e+00,  5.5469e-01],\n",
       "          [ 6.6797e-01,  7.0801e-02,  1.0391e+00,  ...,  3.3750e+00,\n",
       "            3.5625e+00, -1.1328e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 6.3477e-02,  1.1978e-03, -2.0386e-02,  ..., -4.4922e-01,\n",
       "           -7.9102e-02,  4.0430e-01],\n",
       "          [-2.2461e-01,  6.2500e-02,  5.3223e-02,  ...,  3.8086e-01,\n",
       "           -2.9531e+00, -4.2812e+00],\n",
       "          [-6.6895e-02, -9.0625e-01,  6.8359e-01,  ...,  1.3906e+00,\n",
       "           -3.1094e+00, -3.0156e+00],\n",
       "          ...,\n",
       "          [ 6.5625e-01, -5.6396e-02, -5.7422e-01,  ...,  8.5938e-01,\n",
       "            4.3750e-01, -1.6406e+00],\n",
       "          [ 8.7500e-01,  5.0781e-01, -2.1777e-01,  ...,  1.1875e+00,\n",
       "            8.9062e-01, -2.0625e+00],\n",
       "          [ 6.3477e-03, -1.3965e-01,  5.1270e-02,  ..., -1.2188e+00,\n",
       "           -7.2656e-01,  1.6113e-01]],\n",
       "\n",
       "         [[-1.8921e-03,  1.4893e-02,  4.9561e-02,  ...,  7.2656e-01,\n",
       "            3.9062e-01, -2.3594e+00],\n",
       "          [ 2.8711e-01,  9.8145e-02, -9.7656e-02,  ...,  8.9844e-01,\n",
       "           -3.9375e+00,  4.5938e+00],\n",
       "          [ 8.0078e-02, -1.4062e-01,  1.5332e-01,  ...,  9.0234e-01,\n",
       "           -3.6875e+00,  6.0938e+00],\n",
       "          ...,\n",
       "          [ 4.0039e-01,  4.8828e-03, -1.0742e-01,  ...,  1.0625e+00,\n",
       "            2.9492e-01,  4.5938e+00],\n",
       "          [ 6.1035e-02,  5.6641e-02,  3.3203e-02,  ..., -7.3047e-01,\n",
       "            9.9609e-01,  4.5312e+00],\n",
       "          [-9.1797e-02, -3.5645e-02, -5.3906e-01,  ..., -1.4062e+00,\n",
       "           -2.1387e-01,  5.1562e+00]],\n",
       "\n",
       "         [[-4.4861e-03,  1.3123e-02,  1.3062e-02,  ...,  2.9102e-01,\n",
       "           -1.7969e-01, -1.2500e-01],\n",
       "          [ 4.9805e-01, -1.0254e-01,  1.7212e-02,  ...,  2.3750e+00,\n",
       "            5.3516e-01, -1.3438e+00],\n",
       "          [-4.2188e-01, -5.2734e-01,  2.6172e-01,  ...,  1.6250e+00,\n",
       "            1.0938e+00, -1.4453e+00],\n",
       "          ...,\n",
       "          [ 3.5156e-02,  2.3535e-01, -7.3242e-03,  ..., -1.6641e+00,\n",
       "            1.8203e+00,  2.8711e-01],\n",
       "          [ 3.8477e-01, -1.2891e-01, -3.9258e-01,  ..., -9.4922e-01,\n",
       "            2.3594e+00,  2.5938e+00],\n",
       "          [ 1.9775e-02,  7.9590e-02, -6.9336e-02,  ...,  8.7500e-01,\n",
       "            1.5234e+00,  2.2344e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-1.0071e-02, -6.5308e-03, -1.2146e-02,  ...,  1.1108e-02,\n",
       "            1.5381e-02, -4.9438e-03],\n",
       "          [ 6.4062e-01, -1.0859e+00, -8.2422e-01,  ...,  1.2578e+00,\n",
       "            3.7305e-01, -6.9922e-01],\n",
       "          [ 4.8633e-01, -5.8594e-01, -2.3438e-01,  ...,  8.8672e-01,\n",
       "            4.6289e-01, -2.0020e-01],\n",
       "          ...,\n",
       "          [ 5.3906e-01,  2.4872e-03, -4.1797e-01,  ...,  2.8906e-01,\n",
       "           -1.7480e-01, -9.2188e-01],\n",
       "          [ 1.1182e-01, -1.4453e-01, -2.4805e-01,  ..., -5.0781e-01,\n",
       "            9.2773e-02, -3.0078e-01],\n",
       "          [ 7.8906e-01,  7.3047e-01,  3.2031e-01,  ..., -7.0703e-01,\n",
       "            2.2266e-01, -1.6797e-01]],\n",
       "\n",
       "         [[-3.7109e-02, -1.7822e-02,  3.5553e-03,  ..., -1.0071e-02,\n",
       "            4.2725e-03,  1.2634e-02],\n",
       "          [-4.8242e-01,  2.9907e-02,  5.5078e-01,  ..., -9.8145e-02,\n",
       "           -4.6094e-01,  4.2773e-01],\n",
       "          [ 2.9688e-01, -6.6797e-01, -1.8848e-01,  ...,  6.1719e-01,\n",
       "            7.3047e-01,  6.8750e-01],\n",
       "          ...,\n",
       "          [-5.6250e-01,  3.4424e-02, -1.2598e-01,  ...,  7.9688e-01,\n",
       "           -6.0938e-01, -5.1953e-01],\n",
       "          [ 7.5781e-01, -5.9766e-01,  4.4531e-01,  ...,  5.5078e-01,\n",
       "            1.0078e+00,  6.2109e-01],\n",
       "          [ 7.0703e-01,  4.1748e-02,  9.2969e-01,  ...,  4.6484e-01,\n",
       "           -1.0234e+00,  6.9141e-01]],\n",
       "\n",
       "         [[ 1.1063e-03, -8.3542e-04, -1.8311e-02,  ..., -1.4404e-02,\n",
       "           -1.9653e-02, -2.1362e-03],\n",
       "          [ 1.6797e-01,  3.9648e-01,  6.8359e-02,  ...,  1.1816e-01,\n",
       "            5.4199e-02,  4.9316e-02],\n",
       "          [-2.1191e-01,  3.6133e-01,  6.1768e-02,  ...,  4.6094e-01,\n",
       "            1.6504e-01, -1.9922e-01],\n",
       "          ...,\n",
       "          [ 1.3086e-01, -3.9453e-01, -4.4727e-01,  ...,  4.1211e-01,\n",
       "           -4.3945e-01,  6.2988e-02],\n",
       "          [ 2.9688e-01, -1.7480e-01, -2.0605e-01,  ...,  2.4609e-01,\n",
       "           -2.8711e-01, -1.5430e-01],\n",
       "          [-2.5781e-01,  2.3242e-01, -5.8594e-01,  ..., -5.8350e-02,\n",
       "           -6.5625e-01,  3.4570e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.1902e-02,  1.7212e-02, -1.2207e-02,  ..., -2.6978e-02,\n",
       "           -1.4267e-03,  1.0010e-02],\n",
       "          [ 3.9844e-01, -1.0938e-01,  1.9165e-02,  ..., -2.7344e-01,\n",
       "           -9.8267e-03, -4.0820e-01],\n",
       "          [-2.5977e-01,  1.8164e-01, -1.0559e-02,  ..., -4.1992e-01,\n",
       "            3.3789e-01, -9.0234e-01],\n",
       "          ...,\n",
       "          [-2.2461e-02,  5.2734e-01, -3.3008e-01,  ...,  6.5234e-01,\n",
       "           -1.6016e-01, -4.7656e-01],\n",
       "          [-5.0049e-02,  4.6484e-01, -7.7344e-01,  ..., -1.5625e-01,\n",
       "            6.4941e-02, -1.1914e-01],\n",
       "          [ 1.1328e+00,  4.7852e-01, -1.0391e+00,  ...,  1.0469e+00,\n",
       "            1.4551e-01,  5.8984e-01]],\n",
       "\n",
       "         [[-1.9226e-03, -1.4465e-02, -9.0332e-03,  ...,  6.5308e-03,\n",
       "            2.4109e-03,  7.2632e-03],\n",
       "          [ 3.0078e-01,  1.4941e-01,  7.3047e-01,  ...,  3.9062e-01,\n",
       "            3.0664e-01, -3.0859e-01],\n",
       "          [ 1.0449e-01,  8.1641e-01,  2.7930e-01,  ..., -1.7773e-01,\n",
       "            2.0312e-01, -3.4375e-01],\n",
       "          ...,\n",
       "          [-2.7734e-01, -3.7109e-02,  3.2812e-01,  ...,  9.1016e-01,\n",
       "            1.4941e-01,  4.4922e-01],\n",
       "          [ 9.5312e-01,  1.1719e-01,  8.2031e-01,  ...,  7.0312e-01,\n",
       "            1.1621e-01,  2.3828e-01],\n",
       "          [ 5.4297e-01, -8.3496e-02,  1.2266e+00,  ...,  1.5234e+00,\n",
       "           -8.5449e-03, -3.6523e-01]],\n",
       "\n",
       "         [[-1.2878e-02,  1.1353e-02, -1.1377e-01,  ...,  5.4169e-04,\n",
       "           -3.4668e-02, -4.2114e-03],\n",
       "          [ 4.2578e-01, -5.6250e-01,  4.3164e-01,  ...,  4.2236e-02,\n",
       "            3.4766e-01,  7.5000e-01],\n",
       "          [ 6.3672e-01, -6.5625e-01,  2.8711e-01,  ...,  4.8828e-01,\n",
       "            1.3379e-01,  6.0938e-01],\n",
       "          ...,\n",
       "          [-9.5703e-01, -5.5469e-01,  3.8281e-01,  ...,  4.0625e-01,\n",
       "           -2.5391e-01, -5.7812e-01],\n",
       "          [-2.4902e-01, -2.8711e-01,  3.6914e-01,  ...,  8.8379e-02,\n",
       "            8.3496e-02, -1.0559e-02],\n",
       "          [ 9.0234e-01, -1.9043e-01, -3.4766e-01,  ...,  4.7119e-02,\n",
       "            1.1875e+00, -1.1250e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-2.2217e-02,  1.0559e-02, -1.5015e-02,  ...,  5.9375e-01,\n",
       "           -3.5156e-01, -6.8359e-01],\n",
       "          [-2.1582e-01, -7.8125e-02,  4.2969e-02,  ..., -4.0625e-01,\n",
       "           -1.3516e+00, -2.3145e-01],\n",
       "          [-2.3535e-01, -5.0391e-01, -4.9609e-01,  ..., -1.6875e+00,\n",
       "           -2.0469e+00, -3.2715e-02],\n",
       "          ...,\n",
       "          [ 2.0117e-01, -2.4902e-01, -6.5625e-01,  ..., -2.5000e+00,\n",
       "            4.3555e-01, -1.1953e+00],\n",
       "          [ 2.3438e-01,  1.0254e-01, -2.7344e-01,  ..., -3.0469e+00,\n",
       "            8.8281e-01,  7.0703e-01],\n",
       "          [ 2.0020e-01, -1.3574e-01, -7.5195e-02,  ..., -3.2969e+00,\n",
       "           -5.3516e-01, -1.9922e+00]],\n",
       "\n",
       "         [[-9.3994e-03, -8.1177e-03, -1.4526e-02,  ...,  6.0303e-02,\n",
       "           -1.2207e-01, -8.3008e-02],\n",
       "          [ 2.8516e-01,  8.6719e-01,  1.3672e-02,  ..., -6.8359e-02,\n",
       "           -1.1562e+00,  2.7734e-01],\n",
       "          [ 2.6367e-01,  1.9531e-01, -8.8867e-02,  ..., -4.9805e-02,\n",
       "           -1.6641e+00,  6.5625e-01],\n",
       "          ...,\n",
       "          [-8.7500e-01,  5.0391e-01,  2.4219e-01,  ..., -1.1250e+00,\n",
       "           -2.7500e+00,  4.7852e-01],\n",
       "          [ 4.6484e-01,  6.5625e-01, -5.0781e-01,  ...,  2.8711e-01,\n",
       "           -2.1250e+00,  8.9453e-01],\n",
       "          [ 4.0625e-01,  3.3203e-01, -1.2988e-01,  ..., -6.8359e-01,\n",
       "           -5.4297e-01,  1.1875e+00]],\n",
       "\n",
       "         [[ 1.3733e-02,  3.8574e-02, -7.5073e-03,  ..., -5.6250e-01,\n",
       "            1.1963e-01,  2.5977e-01],\n",
       "          [ 1.0156e-01, -3.9453e-01,  1.7090e-02,  ...,  1.9297e+00,\n",
       "            4.7266e-01, -2.1094e+00],\n",
       "          [ 4.7656e-01, -6.5625e-01,  6.2891e-01,  ...,  2.3594e+00,\n",
       "           -1.3438e+00, -1.7188e+00],\n",
       "          ...,\n",
       "          [-3.7109e-01,  3.3789e-01,  3.3008e-01,  ...,  2.0625e+00,\n",
       "           -6.9141e-01,  2.3926e-01],\n",
       "          [-7.0703e-01,  8.0469e-01,  6.8359e-03,  ...,  1.8984e+00,\n",
       "           -1.7891e+00, -2.6875e+00],\n",
       "          [ 4.7266e-01,  1.6309e-01, -4.3945e-01,  ...,  8.2031e-01,\n",
       "           -6.2500e-01, -3.4531e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.4191e-03,  9.4604e-03,  1.9989e-03,  ...,  7.3730e-02,\n",
       "           -3.6621e-02, -6.6406e-02],\n",
       "          [ 2.9492e-01,  4.8828e-02,  5.0781e-01,  ..., -2.4805e-01,\n",
       "            2.9297e-01,  5.1953e-01],\n",
       "          [ 5.1562e-01,  3.2812e-01,  2.8564e-02,  ...,  3.2422e-01,\n",
       "            1.0234e+00,  4.1992e-01],\n",
       "          ...,\n",
       "          [-9.0625e-01,  2.3281e+00, -5.9570e-02,  ..., -3.3447e-02,\n",
       "           -8.5156e-01, -2.8516e-01],\n",
       "          [ 1.7188e+00,  3.9648e-01, -1.3047e+00,  ...,  6.4062e-01,\n",
       "           -1.3965e-01, -1.0234e+00],\n",
       "          [-2.8320e-01,  2.3242e-01,  1.0059e-01,  ..., -1.2734e+00,\n",
       "           -1.3828e+00,  1.2969e+00]],\n",
       "\n",
       "         [[ 4.1260e-02,  1.1230e-02, -7.8735e-03,  ...,  1.9824e-01,\n",
       "           -1.1279e-01,  5.9326e-02],\n",
       "          [-3.5889e-02,  1.6016e-01, -4.5312e-01,  ..., -1.0469e+00,\n",
       "            8.5156e-01, -1.8555e-01],\n",
       "          [-2.5513e-02,  4.1211e-01, -4.2773e-01,  ..., -2.0410e-01,\n",
       "            7.3047e-01, -5.6250e-01],\n",
       "          ...,\n",
       "          [-1.5332e-01, -2.4316e-01, -2.6758e-01,  ..., -9.7266e-01,\n",
       "           -1.2031e+00,  8.3984e-01],\n",
       "          [ 4.5703e-01, -6.6406e-01, -6.9824e-02,  ..., -2.1094e+00,\n",
       "           -8.9844e-01,  1.5391e+00],\n",
       "          [ 3.3984e-01,  8.0566e-02,  6.1719e-01,  ...,  2.5781e+00,\n",
       "           -2.5156e+00,  1.7734e+00]],\n",
       "\n",
       "         [[ 1.2451e-02,  4.6082e-03, -2.5513e-02,  ...,  2.2461e-01,\n",
       "           -1.2891e+00,  2.4414e-01],\n",
       "          [-5.7812e-01,  3.2031e-01, -3.0273e-01,  ..., -1.1328e+00,\n",
       "           -1.5000e+00, -2.5938e+00],\n",
       "          [-2.0605e-01, -1.5991e-02,  1.3574e-01,  ..., -2.2188e+00,\n",
       "           -1.1250e+00, -1.7109e+00],\n",
       "          ...,\n",
       "          [-1.6602e-01, -5.4688e-02,  4.7852e-01,  ..., -1.5938e+00,\n",
       "            1.2656e+00, -3.0156e+00],\n",
       "          [ 3.1445e-01,  2.1484e-01, -1.9531e-01,  ..., -1.1172e+00,\n",
       "            1.9688e+00, -3.3438e+00],\n",
       "          [ 3.9844e-01,  1.4160e-01, -3.6719e-01,  ...,  5.8594e-02,\n",
       "            1.3281e+00, -6.0156e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-4.5471e-03,  3.0859e-01,  5.1880e-03,  ...,  3.4180e-02,\n",
       "            2.2827e-02, -1.1047e-02],\n",
       "          [-1.9062e+00, -1.4922e+00,  2.5586e-01,  ...,  2.8906e-01,\n",
       "           -7.9102e-02,  1.5918e-01],\n",
       "          [-1.0312e+00, -1.1094e+00,  5.3516e-01,  ...,  1.2891e-01,\n",
       "           -2.4609e-01,  1.2109e-01],\n",
       "          ...,\n",
       "          [-3.4180e-01, -1.5703e+00,  6.0156e-01,  ...,  3.1250e-01,\n",
       "            3.9453e-01, -1.7456e-02],\n",
       "          [-1.0156e+00, -1.0234e+00, -5.5908e-02,  ...,  1.2598e-01,\n",
       "            1.1953e+00, -5.5469e-01],\n",
       "          [-1.2422e+00, -1.4922e+00,  2.0312e-01,  ..., -2.1240e-02,\n",
       "            9.1406e-01, -7.0703e-01]],\n",
       "\n",
       "         [[ 1.2329e-02,  2.0386e-02, -2.6367e-02,  ..., -1.9836e-03,\n",
       "           -5.3406e-03, -2.9297e-01],\n",
       "          [-3.9258e-01,  3.9062e-02,  1.0986e-01,  ..., -1.4551e-01,\n",
       "            1.3733e-02,  1.3594e+00],\n",
       "          [-2.4512e-01,  2.5391e-01, -8.1055e-02,  ..., -2.4805e-01,\n",
       "           -6.1328e-01,  1.5000e+00],\n",
       "          ...,\n",
       "          [ 9.6484e-01, -9.1309e-02, -6.3281e-01,  ..., -1.5703e+00,\n",
       "            2.1387e-01,  1.6641e+00],\n",
       "          [ 8.3594e-01, -2.4316e-01, -1.8555e-02,  ..., -5.8203e-01,\n",
       "            1.9336e-01,  1.3984e+00],\n",
       "          [ 1.0234e+00, -1.0859e+00, -8.7109e-01,  ..., -6.5625e-01,\n",
       "            3.0859e-01,  1.1641e+00]],\n",
       "\n",
       "         [[-5.3101e-03,  2.8076e-03, -8.6670e-03,  ...,  9.7656e-03,\n",
       "           -3.0823e-03, -1.7822e-02],\n",
       "          [ 2.5391e-01, -5.5469e-01,  3.4766e-01,  ..., -7.0312e-02,\n",
       "            8.6719e-01, -3.1055e-01],\n",
       "          [ 3.9062e-01, -6.1035e-02,  2.1191e-01,  ...,  3.6133e-02,\n",
       "            1.4453e-01,  8.8501e-03],\n",
       "          ...,\n",
       "          [ 5.5469e-01, -5.7422e-01, -1.5781e+00,  ...,  1.1182e-01,\n",
       "            3.6719e-01,  5.0000e-01],\n",
       "          [ 7.8125e-01,  4.5312e-01, -7.6172e-01,  ...,  1.7188e-01,\n",
       "            4.4141e-01,  2.0996e-01],\n",
       "          [-4.3457e-02,  1.3489e-02, -5.3906e-01,  ..., -3.1641e-01,\n",
       "            2.9907e-02,  2.1875e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-2.1973e-02,  6.0425e-03,  5.1575e-03,  ..., -3.3203e-02,\n",
       "           -9.4604e-03, -1.3611e-02],\n",
       "          [-9.9609e-02, -8.0469e-01,  2.3926e-01,  ...,  1.1841e-02,\n",
       "            1.6113e-01, -5.1562e-01],\n",
       "          [ 2.0898e-01,  7.2327e-03,  2.8125e-01,  ..., -3.0469e-01,\n",
       "           -7.1289e-02, -3.0664e-01],\n",
       "          ...,\n",
       "          [ 7.6172e-02, -8.3594e-01,  4.1260e-02,  ...,  6.7188e-01,\n",
       "            5.1172e-01,  3.3789e-01],\n",
       "          [ 8.0469e-01, -6.3281e-01, -4.7070e-01,  ...,  6.2109e-01,\n",
       "           -6.9922e-01,  5.7031e-01],\n",
       "          [ 7.8906e-01, -3.9258e-01,  2.2363e-01,  ...,  1.5234e+00,\n",
       "            6.3672e-01, -4.5117e-01]],\n",
       "\n",
       "         [[-1.7700e-02, -1.0742e-02, -1.1475e-02,  ..., -8.9111e-03,\n",
       "           -3.7695e-01, -6.5994e-04],\n",
       "          [ 8.1543e-02,  7.3730e-02,  2.9297e-01,  ..., -2.2363e-01,\n",
       "            1.7188e-01, -3.2812e-01],\n",
       "          [-1.4844e-01,  4.9072e-02,  3.1836e-01,  ..., -1.9922e-01,\n",
       "            2.8906e-01, -3.1055e-01],\n",
       "          ...,\n",
       "          [ 3.6133e-01, -7.4219e-01,  4.0039e-01,  ..., -5.1172e-01,\n",
       "            1.2031e+00,  5.9375e-01],\n",
       "          [-6.3965e-02, -4.4531e-01,  5.1562e-01,  ..., -5.4688e-01,\n",
       "            9.4141e-01,  6.4453e-01],\n",
       "          [-3.4180e-01, -4.4727e-01,  6.9922e-01,  ...,  1.3379e-01,\n",
       "            8.5547e-01,  3.7500e-01]],\n",
       "\n",
       "         [[-1.8845e-03, -1.0620e-02,  2.2461e-02,  ..., -1.0071e-02,\n",
       "           -5.8594e-03,  1.0071e-02],\n",
       "          [ 3.5156e-01,  7.6660e-02, -6.2891e-01,  ..., -5.0391e-01,\n",
       "           -2.1289e-01, -9.1406e-01],\n",
       "          [ 4.2188e-01,  5.1953e-01, -2.4609e-01,  ...,  8.8867e-02,\n",
       "            6.0059e-02, -4.0625e-01],\n",
       "          ...,\n",
       "          [-4.6387e-02,  8.3496e-02,  2.3535e-01,  ...,  6.7578e-01,\n",
       "            2.7539e-01, -2.9883e-01],\n",
       "          [ 5.7422e-01, -8.3496e-02, -2.3535e-01,  ..., -4.9316e-02,\n",
       "           -4.1992e-01,  3.0664e-01],\n",
       "          [ 3.0151e-02,  8.0078e-01, -5.0391e-01,  ...,  3.7305e-01,\n",
       "           -5.3906e-01, -1.5137e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-3.2715e-02,  3.0273e-02, -3.0762e-02,  ...,  3.5742e-01,\n",
       "            7.4219e-02,  1.4844e-01],\n",
       "          [-3.6914e-01, -3.5938e-01,  5.8594e-01,  ..., -3.7500e-01,\n",
       "            2.4688e+00,  1.0391e+00],\n",
       "          [-2.7148e-01,  6.7578e-01,  3.0859e-01,  ..., -1.0703e+00,\n",
       "            1.0938e+00, -1.7969e-01],\n",
       "          ...,\n",
       "          [-3.7109e-01,  1.0547e-01, -4.9023e-01,  ..., -8.8281e-01,\n",
       "           -2.1582e-01, -2.1250e+00],\n",
       "          [ 2.1606e-02,  3.9648e-01,  5.2734e-01,  ..., -9.1016e-01,\n",
       "           -1.0498e-01, -5.7031e-01],\n",
       "          [ 1.9434e-01, -5.3125e-01, -6.1328e-01,  ...,  1.8281e+00,\n",
       "            1.0312e+00, -3.6719e-01]],\n",
       "\n",
       "         [[-2.3438e-02,  5.9814e-03, -1.3855e-02,  ...,  3.5938e-01,\n",
       "            4.1992e-01,  6.2891e-01],\n",
       "          [-1.3867e-01,  4.1406e-01,  3.7695e-01,  ..., -9.5312e-01,\n",
       "            8.3594e-01, -2.1875e+00],\n",
       "          [-4.6680e-01,  4.2188e-01,  1.2793e-01,  ..., -1.2578e+00,\n",
       "            1.1484e+00, -2.4844e+00],\n",
       "          ...,\n",
       "          [-2.2461e-01,  1.5039e-01, -9.2969e-01,  ..., -2.0469e+00,\n",
       "           -1.0859e+00, -7.8125e-01],\n",
       "          [-1.0791e-01,  1.1523e-01, -5.1172e-01,  ..., -1.0078e+00,\n",
       "           -6.9922e-01,  4.3750e-01],\n",
       "          [-3.9258e-01,  2.7344e-01, -5.8203e-01,  ..., -2.0625e+00,\n",
       "           -6.0547e-01,  1.3359e+00]],\n",
       "\n",
       "         [[-3.1281e-04,  3.0151e-02,  1.2573e-02,  ..., -1.6724e-02,\n",
       "            1.7969e-01,  3.0273e-01],\n",
       "          [-1.2695e-01,  2.9688e-01,  4.0039e-01,  ...,  1.0938e+00,\n",
       "           -5.0391e-01, -3.3594e-01],\n",
       "          [-3.6328e-01,  3.0273e-01,  5.4297e-01,  ...,  1.3125e+00,\n",
       "           -3.0664e-01,  1.2656e+00],\n",
       "          ...,\n",
       "          [ 3.5938e-01,  2.6367e-01, -7.7344e-01,  ..., -3.7891e-01,\n",
       "            1.2793e-01,  1.5000e+00],\n",
       "          [ 7.5781e-01,  2.4707e-01, -2.2949e-02,  ...,  8.8672e-01,\n",
       "            6.1719e-01,  7.0703e-01],\n",
       "          [ 7.4707e-02,  6.6797e-01,  6.2500e-02,  ...,  8.0859e-01,\n",
       "           -1.3281e+00,  1.3672e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.5757e-02,  6.9580e-03, -1.1230e-02,  ..., -3.8574e-02,\n",
       "           -2.2559e-01, -4.9609e-01],\n",
       "          [-7.0312e-02, -6.0547e-01, -4.6289e-01,  ..., -1.0498e-01,\n",
       "            1.1562e+00, -3.5742e-01],\n",
       "          [-1.0625e+00, -1.1641e+00,  2.4170e-02,  ..., -4.4531e-01,\n",
       "            1.7031e+00, -3.5742e-01],\n",
       "          ...,\n",
       "          [-2.2852e-01, -4.7266e-01, -2.7734e-01,  ..., -1.0547e+00,\n",
       "            3.9844e-01,  4.4727e-01],\n",
       "          [-7.5000e-01,  4.8438e-01,  2.5586e-01,  ..., -3.8281e-01,\n",
       "            6.5918e-02,  1.1484e+00],\n",
       "          [ 1.1133e-01,  1.3672e-02, -3.2812e-01,  ..., -4.4336e-01,\n",
       "           -6.7188e-01,  1.6113e-01]],\n",
       "\n",
       "         [[ 2.0996e-02, -2.8076e-02, -3.2227e-02,  ...,  7.4219e-02,\n",
       "           -1.6797e+00,  1.4648e-02],\n",
       "          [-1.8457e-01, -2.7539e-01,  9.6191e-02,  ..., -1.4062e+00,\n",
       "            2.9688e+00,  1.3203e+00],\n",
       "          [ 6.7383e-02, -2.2217e-02,  3.6719e-01,  ..., -2.0469e+00,\n",
       "            4.0938e+00,  7.6953e-01],\n",
       "          ...,\n",
       "          [-1.8164e-01,  1.0498e-01, -4.4531e-01,  ...,  5.2344e-01,\n",
       "            4.5938e+00,  1.0234e+00],\n",
       "          [-1.2793e-01, -1.0449e-01, -1.6211e-01,  ..., -2.0156e+00,\n",
       "            4.5938e+00,  2.8125e+00],\n",
       "          [-3.0762e-02, -2.7344e-01, -3.3203e-01,  ..., -1.9531e+00,\n",
       "            4.6250e+00,  2.1719e+00]],\n",
       "\n",
       "         [[ 5.3406e-03,  2.1973e-02,  1.1230e-02,  ...,  2.0801e-01,\n",
       "            5.8203e-01,  1.7090e-02],\n",
       "          [ 9.5215e-02,  2.3828e-01,  2.0898e-01,  ..., -1.6328e+00,\n",
       "           -6.7188e-01, -1.3281e+00],\n",
       "          [ 1.1768e-01, -2.0215e-01,  5.4443e-02,  ..., -2.2812e+00,\n",
       "           -3.6914e-01, -1.3203e+00],\n",
       "          ...,\n",
       "          [-1.8066e-01,  1.1670e-01, -1.4844e-01,  ..., -1.1875e+00,\n",
       "           -2.0000e+00, -1.8516e+00],\n",
       "          [-4.2188e-01,  1.9531e-03,  6.6016e-01,  ..., -2.0264e-02,\n",
       "           -1.9844e+00, -2.6250e+00],\n",
       "          [-1.3477e-01,  1.6016e-01,  6.6406e-01,  ..., -6.4453e-02,\n",
       "           -2.2500e+00, -3.0000e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 1.8555e-02, -1.5625e-02,  9.8267e-03,  ..., -5.4016e-03,\n",
       "           -6.5308e-03, -9.3994e-03],\n",
       "          [-3.7500e-01,  9.8047e-01,  6.0547e-01,  ...,  5.6641e-01,\n",
       "            1.1719e+00,  6.4844e-01],\n",
       "          [-4.2578e-01,  5.8984e-01,  3.0078e-01,  ...,  1.0938e+00,\n",
       "            7.0312e-01,  5.8984e-01],\n",
       "          ...,\n",
       "          [ 5.7812e-01,  1.7285e-01, -2.7734e-01,  ...,  3.2227e-01,\n",
       "           -7.6660e-02,  4.4141e-01],\n",
       "          [ 6.7188e-01,  4.5117e-01, -1.7773e-01,  ..., -5.2490e-02,\n",
       "           -2.6758e-01,  3.0859e-01],\n",
       "          [-5.7031e-01, -5.8594e-01, -1.2695e-01,  ...,  2.3535e-01,\n",
       "           -3.1055e-01, -1.0059e-01]],\n",
       "\n",
       "         [[ 1.3916e-02, -7.7515e-03,  8.1787e-03,  ...,  3.6926e-03,\n",
       "           -7.8735e-03, -1.0742e-02],\n",
       "          [ 7.9297e-01, -2.7539e-01,  6.5430e-02,  ...,  6.5918e-02,\n",
       "            3.7842e-02, -3.3984e-01],\n",
       "          [ 9.4141e-01,  2.5269e-02,  9.0332e-02,  ...,  3.5400e-02,\n",
       "            3.6133e-02,  2.3438e-01],\n",
       "          ...,\n",
       "          [ 5.8984e-01, -8.4839e-03, -1.0303e-01,  ...,  4.8828e-01,\n",
       "           -5.6641e-01,  4.0039e-01],\n",
       "          [ 6.5625e-01,  5.7812e-01, -3.9844e-01,  ..., -2.0142e-02,\n",
       "            4.2188e-01,  6.0059e-02],\n",
       "          [ 1.3867e-01,  2.1094e-01, -5.7031e-01,  ..., -8.1250e-01,\n",
       "            9.7656e-02,  1.5625e-01]],\n",
       "\n",
       "         [[-6.4087e-04,  1.2329e-02,  1.5625e-02,  ...,  1.5869e-02,\n",
       "           -1.8066e-02, -4.3457e-02],\n",
       "          [-2.8125e-01, -2.5195e-01, -8.3984e-02,  ...,  3.3203e-01,\n",
       "           -2.5195e-01,  1.3672e+00],\n",
       "          [ 6.0156e-01, -1.7480e-01,  1.5332e-01,  ..., -1.0303e-01,\n",
       "            5.7812e-01,  7.8906e-01],\n",
       "          ...,\n",
       "          [ 9.8828e-01, -1.0303e-01, -3.6328e-01,  ..., -2.1484e-01,\n",
       "            2.5586e-01,  3.5547e-01],\n",
       "          [ 8.3594e-01, -2.7344e-01,  1.9688e+00,  ..., -3.2422e-01,\n",
       "           -8.0469e-01,  5.8203e-01],\n",
       "          [ 4.2773e-01, -1.9336e-01,  4.1797e-01,  ...,  2.9883e-01,\n",
       "           -1.1133e-01, -3.5352e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-4.6387e-03,  9.0332e-03,  3.5248e-03,  ..., -1.3489e-02,\n",
       "           -1.1780e-02, -1.0925e-02],\n",
       "          [ 1.7578e-01,  9.3359e-01,  7.1094e-01,  ..., -5.5078e-01,\n",
       "            4.7656e-01, -6.3281e-01],\n",
       "          [-4.7852e-01,  6.4453e-01,  9.0820e-02,  ..., -4.7461e-01,\n",
       "            5.5859e-01,  1.0449e-01],\n",
       "          ...,\n",
       "          [ 5.2344e-01,  6.8750e-01,  3.8574e-02,  ..., -1.6895e-01,\n",
       "            6.3672e-01, -3.4375e-01],\n",
       "          [ 5.5908e-02,  9.0234e-01, -2.9297e-01,  ...,  6.9531e-01,\n",
       "            6.2109e-01, -2.4316e-01],\n",
       "          [ 2.2461e-01,  1.6504e-01, -8.6719e-01,  ..., -1.7578e-01,\n",
       "            4.7266e-01, -2.5757e-02]],\n",
       "\n",
       "         [[-3.0212e-03,  8.3618e-03, -3.7109e-02,  ...,  1.4526e-02,\n",
       "            8.6060e-03, -1.2329e-02],\n",
       "          [-2.2070e-01, -4.1406e-01, -1.0000e+00,  ..., -9.7266e-01,\n",
       "            3.5156e-01, -2.6367e-01],\n",
       "          [-1.6699e-01, -1.0547e-01, -1.6895e-01,  ..., -5.1562e-01,\n",
       "            7.1289e-02,  3.2715e-02],\n",
       "          ...,\n",
       "          [-4.0283e-02, -4.1602e-01, -4.4922e-01,  ...,  3.1055e-01,\n",
       "            8.1055e-02, -5.1953e-01],\n",
       "          [ 7.7734e-01, -7.9688e-01,  3.8867e-01,  ..., -3.0273e-01,\n",
       "           -3.0273e-01, -7.7344e-01],\n",
       "          [ 2.8906e-01,  2.3242e-01,  4.6094e-01,  ...,  1.7383e-01,\n",
       "            5.8203e-01, -2.4023e-01]],\n",
       "\n",
       "         [[ 5.6763e-03,  2.5269e-02, -8.2397e-03,  ...,  6.5231e-04,\n",
       "            1.6968e-02,  6.0425e-03],\n",
       "          [-3.3008e-01,  2.2559e-01, -1.9165e-02,  ..., -9.5703e-02,\n",
       "           -4.8047e-01, -4.4727e-01],\n",
       "          [-8.5449e-02,  6.2500e-01,  3.1836e-01,  ...,  1.3379e-01,\n",
       "           -5.1172e-01, -3.3594e-01],\n",
       "          ...,\n",
       "          [ 2.9688e-01, -1.7773e-01,  5.7031e-01,  ...,  8.0859e-01,\n",
       "            2.9297e-01, -3.0469e-01],\n",
       "          [ 1.3477e-01,  1.0107e-01, -4.5654e-02,  ...,  3.5547e-01,\n",
       "            1.2451e-01,  4.0430e-01],\n",
       "          [-1.0156e+00,  1.1816e-01, -3.2812e-01,  ...,  6.1328e-01,\n",
       "           -4.9219e-01, -8.3203e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-3.3447e-02, -6.4392e-03,  1.6235e-02,  ..., -2.3242e-01,\n",
       "            5.1562e-01,  7.4219e-01],\n",
       "          [-3.1055e-01, -1.0254e-01,  2.7930e-01,  ...,  1.1572e-01,\n",
       "           -1.1016e+00, -1.0938e+00],\n",
       "          [-7.7734e-01, -3.0469e-01, -3.1641e-01,  ...,  4.8438e-01,\n",
       "           -5.0391e-01, -6.3672e-01],\n",
       "          ...,\n",
       "          [-3.8086e-02,  9.5703e-02, -9.0625e-01,  ...,  8.7891e-01,\n",
       "           -9.0625e-01, -2.0781e+00],\n",
       "          [ 1.3438e+00,  3.1250e-01, -3.7891e-01,  ...,  4.2188e-01,\n",
       "           -1.2344e+00, -2.5000e+00],\n",
       "          [-1.3281e-01, -2.6562e-01, -6.7871e-02,  ..., -1.4746e-01,\n",
       "           -1.1719e+00,  4.6680e-01]],\n",
       "\n",
       "         [[-1.4496e-04,  1.6479e-03, -1.8555e-02,  ..., -2.9883e-01,\n",
       "            5.1758e-02,  1.4893e-02],\n",
       "          [-4.2969e-02,  1.1562e+00,  3.0469e-01,  ...,  2.1973e-01,\n",
       "           -1.2695e-01,  3.1445e-01],\n",
       "          [-3.2031e-01,  3.7891e-01, -1.0547e-01,  ...,  3.4570e-01,\n",
       "            3.7305e-01,  2.5195e-01],\n",
       "          ...,\n",
       "          [ 3.5938e-01, -3.8672e-01, -5.5078e-01,  ...,  1.1094e+00,\n",
       "            6.9922e-01,  2.9492e-01],\n",
       "          [-2.0703e-01, -3.5938e-01,  3.1250e-01,  ...,  5.5859e-01,\n",
       "           -1.3965e-01, -4.6484e-01],\n",
       "          [ 3.8281e-01,  3.1445e-01,  9.0234e-01,  ..., -9.4531e-01,\n",
       "            5.4688e-01,  2.1094e+00]],\n",
       "\n",
       "         [[ 1.1063e-03, -1.0925e-02, -4.8218e-03,  ..., -6.7383e-02,\n",
       "            1.8652e-01, -2.5757e-02],\n",
       "          [ 3.8330e-02,  9.6094e-01, -8.5156e-01,  ..., -6.2500e-01,\n",
       "           -1.1328e-01, -1.5078e+00],\n",
       "          [-2.9102e-01,  7.1484e-01, -4.5312e-01,  ...,  3.1055e-01,\n",
       "           -3.0859e-01, -1.3750e+00],\n",
       "          ...,\n",
       "          [-3.2617e-01, -4.2969e-01,  8.9844e-02,  ...,  1.5391e+00,\n",
       "           -8.5938e-01,  7.2266e-01],\n",
       "          [-4.4531e-01, -5.5469e-01, -4.4336e-01,  ...,  2.4512e-01,\n",
       "           -1.0078e+00, -6.3672e-01],\n",
       "          [-1.5820e-01,  1.2891e-01, -2.3828e-01,  ..., -1.7891e+00,\n",
       "           -5.2344e-01, -4.0234e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 9.7046e-03, -4.6387e-03,  1.1169e-02,  ..., -1.4258e-01,\n",
       "           -6.9336e-02,  4.5166e-02],\n",
       "          [ 1.7188e-01, -6.9141e-01,  8.5938e-02,  ..., -1.4453e-01,\n",
       "           -2.1973e-01,  2.8198e-02],\n",
       "          [-3.7891e-01,  9.1797e-01,  5.6641e-01,  ...,  6.8750e-01,\n",
       "           -1.0391e+00, -8.7500e-01],\n",
       "          ...,\n",
       "          [-4.1211e-01, -1.0156e+00,  2.5391e-01,  ...,  1.3672e-01,\n",
       "           -9.7656e-01, -4.2969e-01],\n",
       "          [ 1.8848e-01, -3.2031e-01, -4.4531e-01,  ..., -2.9102e-01,\n",
       "           -6.9922e-01, -7.3828e-01],\n",
       "          [-4.7461e-01,  9.2188e-01, -1.8906e+00,  ..., -1.7422e+00,\n",
       "           -6.4844e-01, -1.0254e-02]],\n",
       "\n",
       "         [[-3.8300e-03, -2.9373e-04, -8.9722e-03,  ...,  7.1289e-02,\n",
       "           -1.6895e-01, -2.5586e-01],\n",
       "          [-2.0410e-01, -1.4258e-01,  2.8906e-01,  ..., -1.7285e-01,\n",
       "            5.8203e-01, -1.9141e-01],\n",
       "          [ 1.3281e-01, -3.9453e-01,  8.2031e-01,  ..., -5.3125e-01,\n",
       "           -2.2363e-01,  3.1055e-01],\n",
       "          ...,\n",
       "          [-7.1094e-01,  1.1426e-01,  1.6699e-01,  ...,  1.4609e+00,\n",
       "            1.7266e+00,  3.8086e-01],\n",
       "          [ 2.1191e-01,  2.8516e-01,  2.5000e-01,  ..., -1.3965e-01,\n",
       "            1.2188e+00,  9.9609e-01],\n",
       "          [ 9.1016e-01,  0.0000e+00,  5.5859e-01,  ..., -9.5312e-01,\n",
       "            7.4219e-01,  7.3047e-01]],\n",
       "\n",
       "         [[-3.0670e-03,  1.2085e-02, -1.9653e-02,  ...,  2.2949e-01,\n",
       "            2.0605e-01,  8.7891e-02],\n",
       "          [ 3.8672e-01,  1.0156e-01,  6.2500e-01,  ...,  2.3047e-01,\n",
       "           -1.3855e-02,  1.9766e+00],\n",
       "          [ 9.3750e-02, -5.2734e-01,  9.7656e-01,  ..., -1.1963e-01,\n",
       "           -7.5391e-01,  2.9375e+00],\n",
       "          ...,\n",
       "          [-6.8750e-01, -6.5625e-01, -3.5547e-01,  ...,  7.8125e-01,\n",
       "           -2.2559e-01,  1.0938e-01],\n",
       "          [-3.1055e-01, -4.4141e-01, -8.3984e-01,  ...,  9.1406e-01,\n",
       "            2.3145e-01, -9.1797e-01],\n",
       "          [ 1.8164e-01,  5.2734e-01, -3.0078e-01,  ...,  2.1719e+00,\n",
       "            2.4062e+00,  1.9141e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 2.9449e-03,  5.4626e-03,  4.4250e-03,  ..., -2.3346e-03,\n",
       "           -5.9204e-03,  5.1270e-03],\n",
       "          [ 8.3203e-01,  5.1953e-01,  3.3008e-01,  ..., -3.2422e-01,\n",
       "           -4.5703e-01, -1.6309e-01],\n",
       "          [ 6.4062e-01,  3.4570e-01,  7.5391e-01,  ..., -6.5430e-02,\n",
       "           -6.5918e-02, -2.5513e-02],\n",
       "          ...,\n",
       "          [-6.3672e-01,  5.3711e-02,  1.4648e-01,  ...,  8.7891e-02,\n",
       "            1.4141e+00,  1.1406e+00],\n",
       "          [-1.4062e-01, -2.6367e-01, -2.0386e-02,  ...,  2.5195e-01,\n",
       "            6.2500e-02,  6.0938e-01],\n",
       "          [-4.0527e-02, -3.7695e-01, -6.9922e-01,  ...,  1.4844e-01,\n",
       "           -4.1016e-01, -1.8555e-01]],\n",
       "\n",
       "         [[ 1.5030e-03, -2.3315e-02,  2.3041e-03,  ..., -9.7046e-03,\n",
       "           -2.5879e-02, -1.0071e-02],\n",
       "          [ 4.5703e-01,  7.4609e-01, -4.6289e-01,  ..., -6.0547e-02,\n",
       "            9.2773e-03,  3.7500e-01],\n",
       "          [-2.5146e-02, -2.3340e-01,  6.7969e-01,  ..., -1.0938e-01,\n",
       "            5.5847e-03, -8.5449e-03],\n",
       "          ...,\n",
       "          [ 6.3672e-01, -3.3008e-01,  5.1953e-01,  ...,  7.8613e-02,\n",
       "            6.9922e-01, -4.1992e-01],\n",
       "          [-2.6953e-01,  6.3281e-01,  5.8984e-01,  ..., -1.6699e-01,\n",
       "            6.5234e-01, -3.9062e-01],\n",
       "          [ 2.0215e-01, -1.0254e-01, -1.1328e-01,  ...,  2.4658e-02,\n",
       "           -2.4902e-01, -7.0801e-02]],\n",
       "\n",
       "         [[ 8.3008e-03,  5.0354e-03, -1.5640e-03,  ...,  4.9561e-02,\n",
       "           -9.0408e-04, -1.3428e-03],\n",
       "          [-2.6953e-01,  2.6953e-01, -5.5859e-01,  ..., -2.8711e-01,\n",
       "           -7.9297e-01,  2.1973e-01],\n",
       "          [-2.0117e-01,  3.8672e-01, -2.2559e-01,  ..., -5.5859e-01,\n",
       "           -5.7812e-01,  1.8066e-01],\n",
       "          ...,\n",
       "          [ 1.1963e-01,  3.6133e-02, -1.8359e-01,  ..., -3.0396e-02,\n",
       "           -4.5117e-01, -2.5391e-01],\n",
       "          [ 3.9648e-01, -2.2266e-01,  9.1309e-02,  ...,  1.9165e-02,\n",
       "           -3.2422e-01, -3.0859e-01],\n",
       "          [ 4.1797e-01, -1.7334e-02, -3.0664e-01,  ..., -9.6191e-02,\n",
       "           -1.0625e+00, -4.2578e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-9.5215e-03, -2.9907e-03,  7.4219e-02,  ...,  4.3030e-03,\n",
       "           -5.1270e-02, -1.3611e-02],\n",
       "          [ 2.0410e-01,  6.2109e-01, -4.1992e-02,  ..., -9.5703e-02,\n",
       "            2.4219e-01,  4.1016e-01],\n",
       "          [ 2.9492e-01,  2.8320e-01, -5.1953e-01,  ..., -1.1572e-01,\n",
       "            9.8047e-01,  1.2988e-01],\n",
       "          ...,\n",
       "          [ 9.0234e-01, -1.6211e-01, -4.7852e-01,  ..., -1.6602e-01,\n",
       "            6.6406e-01,  7.6562e-01],\n",
       "          [ 4.9414e-01,  1.6797e-01,  1.9629e-01,  ..., -8.1250e-01,\n",
       "            2.2656e-01,  6.3281e-01],\n",
       "          [ 8.8281e-01, -1.7480e-01, -1.8457e-01,  ..., -7.4219e-01,\n",
       "            1.0889e-01,  2.3828e-01]],\n",
       "\n",
       "         [[-5.7068e-03,  4.8218e-03, -1.2512e-02,  ...,  2.6489e-02,\n",
       "           -8.0566e-03,  3.6621e-03],\n",
       "          [ 4.6387e-02, -6.7578e-01, -2.0874e-02,  ...,  1.0781e+00,\n",
       "            8.7500e-01,  1.8262e-01],\n",
       "          [ 2.3730e-01, -6.7969e-01, -3.8086e-01,  ...,  1.0391e+00,\n",
       "            6.9141e-01,  4.3945e-01],\n",
       "          ...,\n",
       "          [ 2.9297e-01, -1.1094e+00,  4.9609e-01,  ...,  8.5449e-02,\n",
       "           -1.0234e+00,  7.0703e-01],\n",
       "          [-1.4832e-02, -8.7891e-01,  4.5898e-01,  ...,  7.5000e-01,\n",
       "           -1.2031e+00,  2.2461e-01],\n",
       "          [ 3.8867e-01, -4.4727e-01, -3.7695e-01,  ...,  3.5547e-01,\n",
       "            8.8867e-02,  4.0234e-01]],\n",
       "\n",
       "         [[ 8.7280e-03, -5.3711e-03,  7.1411e-03,  ..., -5.1575e-03,\n",
       "            1.5259e-02, -3.0518e-03],\n",
       "          [ 7.7344e-01,  2.2949e-01,  9.0234e-01,  ...,  6.6406e-01,\n",
       "            7.3047e-01, -6.2109e-01],\n",
       "          [ 4.9609e-01, -9.0820e-02,  6.0547e-01,  ...,  6.5234e-01,\n",
       "            4.7852e-01, -3.3984e-01],\n",
       "          ...,\n",
       "          [-4.6094e-01,  9.0332e-02,  1.4062e-01,  ..., -4.7852e-01,\n",
       "           -7.8125e-01,  7.1094e-01],\n",
       "          [-9.1016e-01,  4.9414e-01,  7.5391e-01,  ..., -4.0039e-01,\n",
       "           -5.6250e-01,  9.8828e-01],\n",
       "          [-6.7578e-01,  2.3340e-01, -1.7188e-01,  ..., -1.4355e-01,\n",
       "           -6.7578e-01,  8.6719e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-6.8848e-02,  2.3804e-02, -7.7515e-03,  ..., -1.8799e-02,\n",
       "            2.3145e-01, -1.3828e+00],\n",
       "          [ 9.0820e-02, -2.1582e-01, -2.5391e-02,  ...,  1.2266e+00,\n",
       "           -1.1523e-01,  7.0000e+00],\n",
       "          [ 8.0566e-02,  7.7637e-02,  1.6602e-01,  ...,  1.4688e+00,\n",
       "           -6.2109e-01,  6.5000e+00],\n",
       "          ...,\n",
       "          [-2.8516e-01,  4.4189e-02, -8.0078e-02,  ...,  1.7031e+00,\n",
       "           -1.4219e+00,  4.3438e+00],\n",
       "          [ 6.0156e-01,  1.2695e-01, -7.0312e-02,  ...,  1.7422e+00,\n",
       "           -2.0781e+00,  3.4375e+00],\n",
       "          [ 5.4688e-01, -1.9629e-01,  2.1094e-01,  ...,  1.0312e+00,\n",
       "            1.5793e-03,  3.4844e+00]],\n",
       "\n",
       "         [[ 4.2725e-02, -1.6174e-03,  3.4424e-02,  ..., -1.3672e-01,\n",
       "            1.8672e+00, -8.4473e-02],\n",
       "          [ 9.4727e-02, -1.8750e-01, -2.9297e-02,  ..., -6.0547e-01,\n",
       "           -4.2500e+00,  1.7578e+00],\n",
       "          [-1.1719e-02, -1.4062e-01,  3.1055e-01,  ..., -3.2422e-01,\n",
       "           -5.4062e+00,  6.6016e-01],\n",
       "          ...,\n",
       "          [-2.3193e-02, -5.8594e-03,  2.8906e-01,  ...,  6.7969e-01,\n",
       "           -5.5312e+00, -1.0625e+00],\n",
       "          [-2.2070e-01, -1.5625e-01,  1.6602e-01,  ...,  1.3574e-01,\n",
       "           -5.4688e+00, -7.3828e-01],\n",
       "          [ 5.1270e-03,  1.4355e-01, -3.0469e-01,  ...,  3.1494e-02,\n",
       "           -3.6562e+00, -2.3750e+00]],\n",
       "\n",
       "         [[-9.3994e-03, -1.9043e-02,  1.0620e-02,  ...,  6.8848e-02,\n",
       "           -5.3906e-01,  7.5000e-01],\n",
       "          [ 1.2695e-02, -6.7188e-01,  7.5195e-02,  ...,  1.5391e+00,\n",
       "            3.8477e-01, -1.4453e+00],\n",
       "          [ 9.7656e-01, -4.2578e-01,  9.5703e-02,  ...,  1.2109e-01,\n",
       "            4.9219e-01, -1.3984e+00],\n",
       "          ...,\n",
       "          [ 6.2500e-02, -8.7109e-01, -5.5469e-01,  ...,  9.0234e-01,\n",
       "            3.6914e-01, -2.3750e+00],\n",
       "          [ 3.2812e-01, -2.4805e-01, -4.6094e-01,  ...,  1.5078e+00,\n",
       "           -1.1797e+00, -2.1094e+00],\n",
       "          [ 8.0078e-01,  1.8945e-01,  7.5781e-01,  ..., -3.4180e-01,\n",
       "           -1.7285e-01, -2.1719e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.5391e-02, -4.0283e-02,  2.9297e-03,  ..., -6.0938e-01,\n",
       "           -6.9824e-02, -3.2812e-01],\n",
       "          [ 2.8125e-01,  6.8359e-02, -4.6484e-01,  ...,  1.3125e+00,\n",
       "           -4.4688e+00, -2.7344e+00],\n",
       "          [ 4.7266e-01,  1.6602e-01, -3.5889e-02,  ...,  2.1719e+00,\n",
       "           -3.6094e+00, -1.6797e+00],\n",
       "          ...,\n",
       "          [-4.0771e-02,  9.7656e-02,  4.1211e-01,  ...,  3.2500e+00,\n",
       "           -8.1641e-01, -1.3125e+00],\n",
       "          [-2.1387e-01,  4.4922e-02,  4.2969e-01,  ...,  1.2344e+00,\n",
       "           -1.2031e+00, -4.8438e-01],\n",
       "          [ 1.4551e-01, -9.2773e-02,  1.5430e-01,  ...,  1.1797e+00,\n",
       "           -6.4062e-01,  7.9297e-01]],\n",
       "\n",
       "         [[ 3.4424e-02,  7.1716e-03, -9.7046e-03,  ...,  3.7109e-01,\n",
       "           -1.1094e+00,  9.0234e-01],\n",
       "          [-1.6211e-01, -2.1191e-01,  1.1182e-01,  ..., -5.2188e+00,\n",
       "            1.8281e+00, -2.5156e+00],\n",
       "          [-3.7109e-01, -4.2969e-01, -9.1406e-01,  ..., -4.3750e+00,\n",
       "            1.5547e+00, -2.9531e+00],\n",
       "          ...,\n",
       "          [ 3.4180e-02, -7.7148e-02,  2.3828e-01,  ..., -9.0234e-01,\n",
       "            5.5938e+00,  4.1211e-01],\n",
       "          [-1.9043e-02,  2.0996e-02,  4.1406e-01,  ...,  1.0938e+00,\n",
       "            3.2656e+00, -2.0781e+00],\n",
       "          [ 1.9434e-01, -2.2852e-01, -3.7109e-01,  ...,  3.5352e-01,\n",
       "            1.2656e+00,  1.0781e+00]],\n",
       "\n",
       "         [[-6.6223e-03, -1.0925e-02, -2.4261e-03,  ..., -1.8066e-01,\n",
       "           -4.7461e-01,  1.2266e+00],\n",
       "          [ 3.9062e-01,  7.4219e-02, -5.2734e-02,  ...,  2.5000e+00,\n",
       "           -1.5625e+00, -3.4531e+00],\n",
       "          [-6.8359e-02,  3.0859e-01,  1.6602e-02,  ...,  2.2969e+00,\n",
       "           -4.8828e-01, -3.4375e+00],\n",
       "          ...,\n",
       "          [ 3.9795e-02,  3.5352e-01,  1.5137e-01,  ..., -6.7188e-01,\n",
       "            2.0469e+00, -6.3125e+00],\n",
       "          [-2.3145e-01, -5.5469e-01,  9.6094e-01,  ..., -9.8438e-01,\n",
       "            9.9219e-01, -5.4062e+00],\n",
       "          [-1.6016e-01,  3.3984e-01, -3.2031e-01,  ...,  8.5547e-01,\n",
       "            1.3203e+00, -3.3594e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 5.3406e-03, -1.1658e-02, -5.8594e-03,  ...,  5.2795e-03,\n",
       "            3.8300e-03, -1.4267e-03],\n",
       "          [-8.3496e-02,  4.2383e-01,  4.8828e-02,  ...,  8.5547e-01,\n",
       "            4.7607e-02,  5.1172e-01],\n",
       "          [-2.8516e-01,  7.1094e-01,  4.1406e-01,  ...,  7.6953e-01,\n",
       "            1.5991e-02,  5.2734e-01],\n",
       "          ...,\n",
       "          [-1.7676e-01,  3.6621e-02,  1.1621e-01,  ..., -6.5234e-01,\n",
       "            3.4961e-01,  8.7891e-02],\n",
       "          [-4.2969e-01, -1.5137e-01,  1.2598e-01,  ..., -7.8125e-01,\n",
       "            6.8359e-01,  1.2354e-01],\n",
       "          [ 2.1582e-01,  1.0234e+00,  1.1484e+00,  ...,  7.2266e-01,\n",
       "           -2.3145e-01,  4.6289e-01]],\n",
       "\n",
       "         [[-1.4038e-02, -1.2970e-03, -8.1787e-03,  ...,  2.4658e-02,\n",
       "            2.5482e-03, -6.0425e-03],\n",
       "          [-8.4766e-01,  4.1602e-01, -1.0352e-01,  ..., -3.2227e-01,\n",
       "            6.8750e-01, -7.5391e-01],\n",
       "          [-3.5742e-01, -4.8340e-02, -1.0156e+00,  ..., -3.6719e-01,\n",
       "            2.4609e-01, -2.6562e-01],\n",
       "          ...,\n",
       "          [-2.5781e-01,  5.7422e-01,  6.2988e-02,  ...,  9.6191e-02,\n",
       "            6.4453e-01, -1.2793e-01],\n",
       "          [-2.6367e-01,  3.2031e-01, -3.8281e-01,  ..., -4.0820e-01,\n",
       "            8.4766e-01, -1.9531e-01],\n",
       "          [ 5.2185e-03,  4.9805e-01,  2.1240e-02,  ..., -2.7008e-03,\n",
       "            3.6523e-01,  2.0020e-01]],\n",
       "\n",
       "         [[-3.3722e-03,  4.0283e-03,  1.5411e-03,  ...,  1.4801e-03,\n",
       "            2.4109e-03,  7.8735e-03],\n",
       "          [ 5.5078e-01,  9.6484e-01, -2.0312e-01,  ...,  1.4877e-03,\n",
       "            5.4688e-01,  1.9238e-01],\n",
       "          [ 8.0078e-01,  7.4219e-01, -1.4941e-01,  ...,  2.1680e-01,\n",
       "           -3.9844e-01,  1.6309e-01],\n",
       "          ...,\n",
       "          [-1.9336e-01, -5.9375e-01,  1.0625e+00,  ..., -6.6406e-01,\n",
       "           -2.4316e-01,  4.7266e-01],\n",
       "          [-9.9609e-01,  5.7031e-01,  4.8242e-01,  ..., -6.8750e-01,\n",
       "           -4.1211e-01,  8.2520e-02],\n",
       "          [-4.2188e-01,  4.8438e-01,  1.6309e-01,  ..., -7.1875e-01,\n",
       "           -2.6953e-01, -2.9492e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 3.6926e-03, -1.0834e-03, -7.5989e-03,  ..., -3.6774e-03,\n",
       "           -1.8845e-03, -4.2419e-03],\n",
       "          [ 1.6406e-01, -7.4219e-02,  3.2617e-01,  ..., -6.4453e-01,\n",
       "           -3.8672e-01,  1.1641e+00],\n",
       "          [-9.0820e-02, -3.4766e-01, -2.2949e-01,  ..., -5.4297e-01,\n",
       "           -2.6001e-02,  9.2578e-01],\n",
       "          ...,\n",
       "          [ 3.8086e-02,  3.8086e-02, -1.5820e-01,  ...,  7.8516e-01,\n",
       "           -1.1865e-01, -2.7539e-01],\n",
       "          [ 8.8281e-01,  3.1836e-01,  1.0791e-01,  ...,  3.9648e-01,\n",
       "            5.7129e-02,  7.7148e-02],\n",
       "          [-1.0596e-01,  2.5781e-01,  7.1484e-01,  ...,  6.4941e-02,\n",
       "           -2.3828e-01,  5.0049e-02]],\n",
       "\n",
       "         [[ 1.3428e-03, -5.1575e-03,  8.0566e-03,  ...,  5.3711e-03,\n",
       "           -5.3101e-03, -2.0874e-02],\n",
       "          [ 7.6172e-01,  2.9883e-01, -2.6172e-01,  ..., -6.1279e-02,\n",
       "            8.1250e-01,  5.6152e-02],\n",
       "          [ 5.0293e-02, -4.7266e-01, -2.6758e-01,  ...,  3.5645e-02,\n",
       "            4.6680e-01,  1.6602e-01],\n",
       "          ...,\n",
       "          [-1.0547e+00,  1.1172e+00,  2.5977e-01,  ..., -6.4062e-01,\n",
       "            5.8594e-01,  1.3125e+00],\n",
       "          [-1.2578e+00,  6.6016e-01, -5.6641e-01,  ..., -1.1328e-01,\n",
       "            4.4141e-01,  1.0859e+00],\n",
       "          [-8.0566e-02,  3.8086e-01,  1.5137e-01,  ..., -7.5781e-01,\n",
       "            6.2500e-01,  4.0430e-01]],\n",
       "\n",
       "         [[ 2.6093e-03,  2.6398e-03,  3.7537e-03,  ..., -1.0010e-02,\n",
       "           -2.3041e-03,  1.4725e-03],\n",
       "          [-3.0273e-01, -1.0938e+00,  1.2500e-01,  ...,  1.2451e-01,\n",
       "           -7.3438e-01,  5.2246e-02],\n",
       "          [-7.0703e-01, -1.2578e+00,  2.0801e-01,  ...,  5.3516e-01,\n",
       "           -4.2969e-01, -2.0605e-01],\n",
       "          ...,\n",
       "          [ 1.7500e+00,  6.6797e-01, -1.2188e+00,  ..., -8.3203e-01,\n",
       "           -2.4609e-01, -3.7109e-01],\n",
       "          [ 1.2656e+00,  5.6250e-01, -5.4297e-01,  ..., -2.0703e-01,\n",
       "           -1.4844e-01, -2.2339e-02],\n",
       "          [ 1.5234e+00, -2.1729e-02, -9.7656e-01,  ..., -1.9922e-01,\n",
       "           -1.9336e-01,  1.5039e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-5.0659e-03, -1.3184e-02, -6.3782e-03,  ...,  2.0312e-01,\n",
       "           -1.0400e-01, -2.0996e-01],\n",
       "          [-4.7852e-02, -7.5684e-02,  7.6172e-02,  ..., -1.6641e+00,\n",
       "            1.6172e+00,  2.8438e+00],\n",
       "          [ 5.4688e-01,  1.6602e-01,  3.4180e-01,  ..., -3.4180e-01,\n",
       "            1.7344e+00,  2.4844e+00],\n",
       "          ...,\n",
       "          [-2.4512e-01,  2.5195e-01, -2.5000e-01,  ...,  2.2344e+00,\n",
       "            1.3047e+00,  1.4062e+00],\n",
       "          [-2.2070e-01,  2.2266e-01, -8.3496e-02,  ...,  1.8828e+00,\n",
       "            1.3906e+00,  9.7656e-01],\n",
       "          [ 1.2109e-01,  2.7930e-01,  5.3711e-02,  ...,  2.9531e+00,\n",
       "            8.2422e-01,  2.5000e+00]],\n",
       "\n",
       "         [[ 2.9297e-02, -7.7515e-03,  4.0283e-03,  ...,  3.1445e-01,\n",
       "            1.0859e+00, -3.3203e-01],\n",
       "          [ 7.2266e-02, -1.8750e-01,  6.5234e-01,  ..., -3.0156e+00,\n",
       "           -3.9219e+00,  1.1016e+00],\n",
       "          [ 2.6953e-01, -2.1875e-01,  2.8125e-01,  ..., -1.8516e+00,\n",
       "           -4.2812e+00,  7.3828e-01],\n",
       "          ...,\n",
       "          [ 1.9531e-01,  1.1377e-01, -3.9648e-01,  ..., -3.1406e+00,\n",
       "           -3.5781e+00,  1.1172e+00],\n",
       "          [ 7.7148e-02,  1.9727e-01, -3.4180e-01,  ..., -4.5625e+00,\n",
       "           -4.1562e+00,  3.0078e-01],\n",
       "          [-2.0996e-02,  1.2939e-02,  7.8125e-03,  ..., -2.9062e+00,\n",
       "           -3.4688e+00, -5.3125e-01]],\n",
       "\n",
       "         [[-1.5625e-02, -8.4839e-03, -7.0801e-03,  ...,  4.5117e-01,\n",
       "            3.5156e-01, -3.8672e-01],\n",
       "          [-3.3203e-02, -4.7266e-01, -1.4746e-01,  ..., -1.6250e+00,\n",
       "            1.0132e-02, -3.3398e-01],\n",
       "          [-7.0312e-02,  7.4219e-01, -4.3359e-01,  ..., -1.1094e+00,\n",
       "           -5.7422e-01,  4.9805e-02],\n",
       "          ...,\n",
       "          [-4.1406e-01,  8.4375e-01,  8.3594e-01,  ..., -2.2656e+00,\n",
       "           -4.5703e-01,  1.0986e-01],\n",
       "          [-5.6250e-01,  9.3750e-02,  9.3750e-01,  ..., -1.8828e+00,\n",
       "           -1.9219e+00,  6.6797e-01],\n",
       "          [-1.6504e-01,  2.0703e-01,  2.7148e-01,  ..., -7.4219e-01,\n",
       "           -9.5703e-01,  3.3691e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.3367e-02,  1.0132e-02,  3.3875e-03,  ...,  1.7578e+00,\n",
       "           -3.0469e-01, -3.6328e-01],\n",
       "          [ 4.1406e-01, -4.1602e-01,  3.2617e-01,  ..., -4.9688e+00,\n",
       "           -6.6797e-01,  6.6797e-01],\n",
       "          [ 2.8320e-01, -2.5586e-01, -2.0215e-01,  ..., -5.7812e+00,\n",
       "            1.1035e-01,  8.2812e-01],\n",
       "          ...,\n",
       "          [-3.1445e-01, -2.3145e-01, -1.4258e-01,  ..., -6.4062e+00,\n",
       "           -7.4219e-01,  1.1016e+00],\n",
       "          [ 4.2969e-02, -3.9062e-01,  3.4180e-01,  ..., -6.5938e+00,\n",
       "            1.8203e+00,  1.2578e+00],\n",
       "          [ 8.3008e-02,  4.0283e-02, -2.4023e-01,  ..., -6.0938e+00,\n",
       "            7.1484e-01, -2.9297e-02]],\n",
       "\n",
       "         [[ 2.0020e-02, -1.8433e-02,  2.4414e-03,  ..., -9.8438e-01,\n",
       "           -3.2031e-01, -1.0986e-01],\n",
       "          [-5.7031e-01, -1.5820e-01,  5.9375e-01,  ...,  5.9375e+00,\n",
       "           -7.4609e-01,  2.3906e+00],\n",
       "          [-7.8516e-01,  3.1641e-01, -3.7891e-01,  ...,  6.9062e+00,\n",
       "           -3.8086e-01,  1.7344e+00],\n",
       "          ...,\n",
       "          [-1.6406e-01, -1.6211e-01, -2.9541e-02,  ...,  7.8438e+00,\n",
       "           -1.5000e+00,  2.3594e+00],\n",
       "          [-2.4414e-01,  6.5918e-02, -1.0791e-01,  ...,  6.8125e+00,\n",
       "           -1.0781e+00,  1.7422e+00],\n",
       "          [-2.4902e-01, -5.1172e-01,  6.5625e-01,  ...,  7.6250e+00,\n",
       "           -1.6016e+00, -1.2266e+00]],\n",
       "\n",
       "         [[ 2.6611e-02,  8.4839e-03,  2.0996e-02,  ...,  2.2705e-02,\n",
       "            1.1328e+00, -2.5146e-02],\n",
       "          [-1.8262e-01,  4.0430e-01,  6.7969e-01,  ...,  1.3281e+00,\n",
       "           -3.0469e+00,  1.6484e+00],\n",
       "          [ 2.9297e-02, -2.9297e-01,  3.9453e-01,  ...,  9.9609e-01,\n",
       "           -2.9844e+00,  2.7969e+00],\n",
       "          ...,\n",
       "          [ 4.1602e-01, -7.9102e-02, -1.1719e-02,  ...,  9.0234e-01,\n",
       "           -1.5625e-01,  4.0625e+00],\n",
       "          [-3.4180e-01,  2.8809e-02,  4.4141e-01,  ...,  1.2354e-01,\n",
       "           -1.0234e+00,  1.5391e+00],\n",
       "          [-3.8281e-01,  5.8203e-01,  2.5781e-01,  ..., -2.1484e-01,\n",
       "           -1.7090e-01,  2.4219e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 1.2268e-02, -1.3489e-02, -6.8970e-03,  ..., -1.3489e-02,\n",
       "           -1.7944e-02,  1.1292e-02],\n",
       "          [-2.2949e-01,  8.9355e-02, -2.6367e-01,  ..., -4.1602e-01,\n",
       "           -6.6016e-01, -2.9883e-01],\n",
       "          [-4.1406e-01,  3.6621e-02, -2.4707e-01,  ..., -4.4336e-01,\n",
       "           -1.7871e-01,  9.5703e-02],\n",
       "          ...,\n",
       "          [ 1.3672e-01,  1.6016e-01,  2.1094e-01,  ...,  1.9238e-01,\n",
       "           -3.2031e-01,  4.4434e-02],\n",
       "          [ 1.2695e-01, -4.5654e-02,  7.1484e-01,  ..., -5.6250e-01,\n",
       "            1.2146e-02,  6.9141e-01],\n",
       "          [-6.7383e-02,  1.0010e-01,  5.2344e-01,  ..., -5.4688e-01,\n",
       "            6.9141e-01,  2.1191e-01]],\n",
       "\n",
       "         [[ 7.0496e-03,  2.0386e-02, -1.6022e-03,  ...,  9.8877e-03,\n",
       "           -1.0132e-02,  1.0559e-02],\n",
       "          [-8.3008e-02, -8.6328e-01, -4.0625e-01,  ...,  1.7773e-01,\n",
       "            2.5195e-01,  4.6680e-01],\n",
       "          [-5.2734e-01, -2.8906e-01, -5.6641e-01,  ..., -3.3691e-02,\n",
       "           -4.6875e-01,  1.8555e-01],\n",
       "          ...,\n",
       "          [ 5.2490e-02,  3.9648e-01,  4.5117e-01,  ...,  3.2471e-02,\n",
       "            3.4180e-01,  1.4648e-01],\n",
       "          [ 8.3203e-01,  6.4844e-01,  1.2109e+00,  ..., -5.3125e-01,\n",
       "            6.8750e-01, -8.2031e-02],\n",
       "          [-1.2256e-01, -2.5195e-01,  5.1562e-01,  ..., -6.8750e-01,\n",
       "            7.1289e-02,  5.5859e-01]],\n",
       "\n",
       "         [[ 1.7334e-02,  1.6479e-03, -6.7749e-03,  ..., -8.2520e-02,\n",
       "            1.1921e-04, -6.9580e-03],\n",
       "          [-1.1016e+00,  1.9531e-01, -1.6113e-01,  ..., -1.1016e+00,\n",
       "            3.7109e-01,  2.0752e-02],\n",
       "          [-3.5742e-01, -4.7852e-02, -6.2891e-01,  ..., -4.2188e-01,\n",
       "            5.9375e-01,  8.3496e-02],\n",
       "          ...,\n",
       "          [-8.5938e-01, -1.5820e-01,  3.9648e-01,  ..., -7.6172e-01,\n",
       "            5.9814e-02, -2.7930e-01],\n",
       "          [-7.6172e-01, -1.7700e-02,  1.0469e+00,  ..., -8.3984e-01,\n",
       "            3.1055e-01, -6.5625e-01],\n",
       "          [ 1.6797e-01, -5.0000e-01, -4.4922e-01,  ..., -1.9434e-01,\n",
       "           -6.2109e-01, -1.0859e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.5869e-02, -7.8735e-03, -3.1738e-03,  ...,  5.2795e-03,\n",
       "            2.1210e-03,  6.3171e-03],\n",
       "          [ 6.7969e-01, -9.1797e-02,  2.0020e-01,  ..., -3.3691e-02,\n",
       "            1.2656e+00, -1.2891e-01],\n",
       "          [ 2.4414e-01,  5.6641e-01, -1.9653e-02,  ..., -4.4531e-01,\n",
       "            6.3672e-01, -3.9258e-01],\n",
       "          ...,\n",
       "          [ 6.9531e-01,  3.8867e-01, -5.0293e-02,  ..., -2.0508e-01,\n",
       "            2.4219e-01,  6.2109e-01],\n",
       "          [ 5.7422e-01,  1.1797e+00,  2.2656e-01,  ...,  7.3242e-02,\n",
       "            6.3672e-01,  1.9141e-01],\n",
       "          [ 5.2344e-01, -1.8848e-01, -3.6719e-01,  ...,  6.1328e-01,\n",
       "           -1.0107e-01,  1.8262e-01]],\n",
       "\n",
       "         [[-1.2878e-02, -1.0010e-02, -5.8899e-03,  ..., -2.3804e-03,\n",
       "            3.8147e-03, -1.6235e-02],\n",
       "          [-5.8594e-01,  6.7578e-01, -3.6719e-01,  ...,  4.1406e-01,\n",
       "           -2.0117e-01,  1.5039e-01],\n",
       "          [ 7.0312e-02,  9.2578e-01, -3.4912e-02,  ...,  3.3203e-02,\n",
       "            4.3335e-03,  4.6094e-01],\n",
       "          ...,\n",
       "          [-3.7598e-02,  4.3164e-01,  1.6235e-02,  ...,  1.9629e-01,\n",
       "            4.1602e-01,  4.2188e-01],\n",
       "          [ 6.2500e-01,  8.0078e-01, -5.2734e-01,  ...,  2.9102e-01,\n",
       "            2.4121e-01,  3.3789e-01],\n",
       "          [-2.8320e-01,  2.9297e-01,  1.0391e+00,  ..., -9.4238e-02,\n",
       "           -4.4727e-01, -7.2266e-01]],\n",
       "\n",
       "         [[-1.1414e-02, -7.0190e-03, -8.9844e-02,  ...,  1.9989e-03,\n",
       "            5.1117e-04, -4.1504e-03],\n",
       "          [ 2.0508e-01, -6.2500e-01,  4.8633e-01,  ..., -2.8906e-01,\n",
       "            2.0703e-01,  6.6797e-01],\n",
       "          [ 1.9336e-01, -5.5078e-01,  6.2109e-01,  ..., -4.4727e-01,\n",
       "           -2.4414e-01,  4.3750e-01],\n",
       "          ...,\n",
       "          [-6.2500e-01, -6.4062e-01,  1.9688e+00,  ..., -1.9844e+00,\n",
       "            3.0975e-03, -4.7656e-01],\n",
       "          [-7.0703e-01, -1.2158e-01,  1.8281e+00,  ..., -1.4688e+00,\n",
       "            5.9766e-01, -3.2422e-01],\n",
       "          [-1.8750e-01, -9.2578e-01,  1.2812e+00,  ..., -8.9453e-01,\n",
       "           -7.4219e-02,  3.6328e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 8.1177e-03, -4.6997e-03,  1.5198e-02,  ..., -2.7539e-01,\n",
       "            3.9062e-02, -5.0000e-01],\n",
       "          [-3.6914e-01,  8.7109e-01,  4.5898e-01,  ..., -9.0234e-01,\n",
       "           -5.7812e-01,  5.5469e-01],\n",
       "          [ 9.2188e-01,  4.7461e-01,  4.2188e-01,  ...,  1.7188e-01,\n",
       "            2.0996e-02,  4.9805e-01],\n",
       "          ...,\n",
       "          [-9.2188e-01,  2.9297e-01,  1.1094e+00,  ...,  7.6953e-01,\n",
       "            2.0312e+00, -1.6602e-01],\n",
       "          [-1.1016e+00, -1.8164e-01,  7.6562e-01,  ...,  1.6406e+00,\n",
       "            2.0625e+00, -3.9795e-02],\n",
       "          [-8.5938e-02,  1.1172e+00,  4.8828e-01,  ...,  2.7188e+00,\n",
       "            1.6953e+00,  6.6016e-01]],\n",
       "\n",
       "         [[-1.6724e-02, -9.3994e-03,  1.9287e-02,  ...,  3.1128e-02,\n",
       "            3.0469e-01, -5.0781e-02],\n",
       "          [-1.1250e+00, -4.5703e-01, -2.3438e-01,  ...,  3.8867e-01,\n",
       "            6.4453e-01, -4.0625e-01],\n",
       "          [-4.1602e-01,  3.1250e-01, -3.4180e-01,  ...,  1.1182e-01,\n",
       "            7.8516e-01, -9.9219e-01],\n",
       "          ...,\n",
       "          [ 1.2207e-01, -8.2812e-01,  2.0508e-01,  ..., -1.2344e+00,\n",
       "           -1.0859e+00, -1.0000e+00],\n",
       "          [-3.3203e-02, -1.5000e+00,  8.7891e-01,  ...,  1.1797e+00,\n",
       "           -1.4922e+00, -3.3936e-02],\n",
       "          [-5.1562e-01, -7.5781e-01, -2.0215e-01,  ..., -1.4844e+00,\n",
       "           -1.6504e-01,  3.7500e-01]],\n",
       "\n",
       "         [[-1.1215e-03,  2.2827e-02, -7.2937e-03,  ...,  3.3203e-02,\n",
       "            1.4531e+00, -5.7129e-02],\n",
       "          [ 1.9375e+00,  1.2031e+00, -1.2578e+00,  ..., -6.6016e-01,\n",
       "           -3.5156e+00, -3.1641e-01],\n",
       "          [ 2.4219e-01,  4.7070e-01, -3.5547e-01,  ...,  2.8711e-01,\n",
       "           -4.0312e+00,  8.3594e-01],\n",
       "          ...,\n",
       "          [-1.3770e-01,  6.6406e-01, -2.0215e-01,  ...,  7.4219e-01,\n",
       "           -4.5312e+00,  6.9141e-01],\n",
       "          [ 3.9258e-01, -5.3516e-01, -5.5469e-01,  ..., -3.0859e-01,\n",
       "           -4.2812e+00, -4.4141e-01],\n",
       "          [ 5.5859e-01,  8.5156e-01, -1.4062e+00,  ...,  9.5703e-01,\n",
       "           -4.7500e+00,  1.9336e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 7.1716e-03,  3.9795e-02,  1.0742e-02,  ...,  4.0625e-01,\n",
       "           -1.4160e-01, -3.6328e-01],\n",
       "          [-2.5391e-01, -5.8594e-01,  1.4551e-01,  ...,  5.5469e-01,\n",
       "           -3.4180e-01,  2.2168e-01],\n",
       "          [-6.3281e-01, -6.1719e-01,  5.1172e-01,  ..., -7.3438e-01,\n",
       "            9.2578e-01,  1.6328e+00],\n",
       "          ...,\n",
       "          [ 1.2422e+00, -3.3203e-02,  9.5312e-01,  ..., -4.2969e-01,\n",
       "           -9.1309e-02,  2.5000e+00],\n",
       "          [ 7.5684e-02, -2.2656e-01, -2.4023e-01,  ..., -1.4453e-01,\n",
       "           -3.3281e+00,  1.5625e+00],\n",
       "          [-5.6641e-01,  3.5742e-01, -1.7578e-01,  ...,  2.7148e-01,\n",
       "            3.2812e-01,  1.7500e+00]],\n",
       "\n",
       "         [[-1.0315e-02,  3.1738e-02,  1.8311e-02,  ...,  2.1250e+00,\n",
       "            2.2754e-01,  5.1562e-01],\n",
       "          [ 2.2461e-01, -5.3125e-01, -1.1562e+00,  ..., -3.2031e+00,\n",
       "            4.1992e-01, -4.9688e+00],\n",
       "          [ 2.2461e-02, -4.2188e-01, -4.3750e-01,  ..., -4.6562e+00,\n",
       "           -7.4609e-01, -4.6250e+00],\n",
       "          ...,\n",
       "          [-3.6328e-01,  4.6484e-01,  3.5156e-02,  ..., -4.5938e+00,\n",
       "           -8.7891e-01, -1.9238e-01],\n",
       "          [ 2.1973e-01,  4.6680e-01,  1.0312e+00,  ..., -5.0938e+00,\n",
       "            5.7422e-01, -4.6875e-01],\n",
       "          [ 4.9805e-01,  4.1016e-02,  4.4141e-01,  ..., -3.5000e+00,\n",
       "            1.1016e+00, -1.9922e+00]],\n",
       "\n",
       "         [[ 3.7842e-02,  5.4199e-02, -3.0518e-02,  ..., -1.1641e+00,\n",
       "           -2.0508e-01,  3.1055e-01],\n",
       "          [-1.8262e-01, -5.0781e-01, -1.6113e-01,  ...,  1.9297e+00,\n",
       "           -2.3828e-01,  3.3984e-01],\n",
       "          [-4.2188e-01, -4.1992e-01,  1.9336e-01,  ...,  2.3750e+00,\n",
       "            1.0781e+00,  7.4219e-01],\n",
       "          ...,\n",
       "          [-1.8652e-01, -1.1035e-01, -7.5684e-02,  ...,  2.5156e+00,\n",
       "           -3.5938e-01, -1.1797e+00],\n",
       "          [ 6.6895e-02, -2.6562e-01, -5.4199e-02,  ...,  2.5156e+00,\n",
       "            4.1211e-01, -6.0938e-01],\n",
       "          [ 2.6953e-01,  2.9102e-01, -3.9062e-03,  ...,  2.5938e+00,\n",
       "            8.9844e-01,  4.7656e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 8.9722e-03,  3.4790e-03, -1.5625e-02,  ...,  8.0566e-03,\n",
       "            6.1646e-03, -6.3171e-03],\n",
       "          [ 1.9727e-01, -8.3008e-02, -2.0117e-01,  ..., -9.5312e-01,\n",
       "            1.0312e+00,  6.8054e-03],\n",
       "          [ 4.8584e-02,  4.8584e-02, -1.3770e-01,  ..., -4.3164e-01,\n",
       "            4.2188e-01,  3.3594e-01],\n",
       "          ...,\n",
       "          [-1.2734e+00, -3.8672e-01,  3.4766e-01,  ...,  5.4688e-01,\n",
       "            3.1128e-02, -2.0703e-01],\n",
       "          [-8.7500e-01,  2.0508e-01,  8.5547e-01,  ...,  3.0762e-02,\n",
       "            8.4961e-02, -5.1953e-01],\n",
       "          [-2.3438e-01, -6.7444e-03,  5.7031e-01,  ...,  8.7891e-02,\n",
       "           -7.1094e-01, -1.0693e-01]],\n",
       "\n",
       "         [[-1.0071e-02,  2.9907e-02,  5.8899e-03,  ..., -3.7384e-03,\n",
       "           -2.3956e-03, -1.4832e-02],\n",
       "          [ 1.2061e-01, -4.0430e-01,  1.6602e-01,  ..., -4.1602e-01,\n",
       "            4.6875e-01, -3.4766e-01],\n",
       "          [ 4.5898e-01, -3.8672e-01,  5.3125e-01,  ..., -4.2578e-01,\n",
       "            6.3672e-01, -1.1719e-01],\n",
       "          ...,\n",
       "          [-6.5234e-01,  2.0801e-01,  2.5586e-01,  ...,  8.4766e-01,\n",
       "           -4.1602e-01, -7.1777e-02],\n",
       "          [-2.8442e-02,  3.1641e-01, -6.0730e-03,  ...,  8.6719e-01,\n",
       "           -6.2109e-01,  3.1836e-01],\n",
       "          [ 1.9434e-01,  2.7930e-01,  2.3340e-01,  ..., -2.7734e-01,\n",
       "           -2.0215e-01,  7.8906e-01]],\n",
       "\n",
       "         [[ 1.0803e-02,  1.7334e-02, -1.6235e-02,  ...,  1.3611e-02,\n",
       "            9.3994e-03,  2.7588e-02],\n",
       "          [-7.0801e-02,  4.8828e-01, -2.4805e-01,  ...,  9.5312e-01,\n",
       "            4.3213e-02, -7.4219e-01],\n",
       "          [-1.7480e-01,  5.7031e-01,  4.5312e-01,  ...,  9.4531e-01,\n",
       "           -5.7129e-02, -6.3672e-01],\n",
       "          ...,\n",
       "          [-1.3516e+00,  3.2031e-01,  1.0625e+00,  ...,  1.3906e+00,\n",
       "            1.1523e-01,  1.2793e-01],\n",
       "          [-1.1094e+00,  2.2363e-01,  7.8125e-01,  ...,  8.2422e-01,\n",
       "           -1.2500e-01,  4.4434e-02],\n",
       "          [-9.5703e-01, -5.5859e-01,  1.7812e+00,  ...,  1.9375e+00,\n",
       "            2.1875e-01, -6.6223e-03]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-2.6001e-02, -1.1475e-02, -7.4158e-03,  ..., -1.0559e-02,\n",
       "           -2.0752e-03,  9.6436e-03],\n",
       "          [-8.9062e-01, -3.5156e-02, -1.8555e-01,  ...,  6.3672e-01,\n",
       "           -8.9453e-01,  1.0889e-01],\n",
       "          [-1.1953e+00, -5.2734e-01, -3.9648e-01,  ...,  6.5625e-01,\n",
       "           -3.6328e-01,  2.3340e-01],\n",
       "          ...,\n",
       "          [ 6.8750e-01,  3.7891e-01,  4.8438e-01,  ...,  7.9688e-01,\n",
       "            5.4688e-01,  2.8711e-01],\n",
       "          [ 5.8203e-01,  1.2656e+00,  1.9141e+00,  ...,  2.3193e-02,\n",
       "            2.6245e-02, -5.9766e-01],\n",
       "          [-6.4062e-01, -6.1328e-01, -1.4258e-01,  ...,  6.6895e-02,\n",
       "            8.9355e-02,  1.1816e-01]],\n",
       "\n",
       "         [[-9.2697e-04, -3.5400e-03,  6.2561e-03,  ..., -3.0640e-02,\n",
       "            1.0742e-02, -8.6594e-04],\n",
       "          [ 2.8198e-02, -4.0430e-01, -8.1250e-01,  ...,  2.1973e-01,\n",
       "           -4.8584e-02, -1.0693e-01],\n",
       "          [ 4.7119e-02, -3.0664e-01, -8.6719e-01,  ...,  2.2168e-01,\n",
       "            1.8066e-01,  1.0254e-01],\n",
       "          ...,\n",
       "          [ 2.4512e-01, -3.5352e-01,  5.3906e-01,  ...,  2.5391e-01,\n",
       "           -1.4551e-01,  3.1250e-02],\n",
       "          [ 1.1797e+00, -7.5781e-01,  7.1875e-01,  ..., -2.3340e-01,\n",
       "            4.6094e-01,  1.0547e+00],\n",
       "          [ 1.6968e-02, -1.2422e+00,  2.7148e-01,  ..., -6.6016e-01,\n",
       "            2.0801e-01,  1.7969e-01]],\n",
       "\n",
       "         [[ 8.1177e-03,  7.1716e-03,  1.1353e-02,  ...,  1.1963e-02,\n",
       "            1.0925e-02, -1.2024e-02],\n",
       "          [ 2.0605e-01, -1.3594e+00, -3.9844e-01,  ..., -1.3086e-01,\n",
       "            5.3125e-01, -2.5977e-01],\n",
       "          [ 7.7344e-01, -7.5781e-01, -2.2949e-01,  ..., -7.9297e-01,\n",
       "           -4.1406e-01, -2.4316e-01],\n",
       "          ...,\n",
       "          [ 2.7539e-01, -3.3264e-03, -3.3594e-01,  ..., -3.2031e-01,\n",
       "            6.8359e-03,  5.3125e-01],\n",
       "          [-2.2754e-01, -1.0010e-01, -4.7852e-01,  ..., -5.3516e-01,\n",
       "           -2.6172e-01,  6.5625e-01],\n",
       "          [ 7.5195e-02, -9.6484e-01, -3.3008e-01,  ..., -3.9258e-01,\n",
       "            3.9258e-01, -3.4180e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-3.9673e-03,  6.9580e-03,  1.5198e-02,  ..., -1.8555e-01,\n",
       "           -1.9684e-03, -7.7734e-01],\n",
       "          [-5.0781e-01,  1.2451e-01,  4.3359e-01,  ...,  7.9688e-01,\n",
       "            1.3359e+00,  2.5156e+00],\n",
       "          [ 1.9434e-01,  3.7500e-01,  4.0234e-01,  ..., -1.9336e-01,\n",
       "            1.8047e+00,  2.7031e+00],\n",
       "          ...,\n",
       "          [ 2.3242e-01,  1.0986e-01, -4.3945e-01,  ...,  3.3789e-01,\n",
       "            8.9062e-01,  1.6250e+00],\n",
       "          [-7.8125e-03,  5.2344e-01, -1.1865e-01,  ...,  2.0781e+00,\n",
       "            4.5898e-01,  9.9609e-01],\n",
       "          [-2.6953e-01,  5.3125e-01, -4.6094e-01,  ...,  3.2227e-01,\n",
       "            1.3906e+00,  2.4707e-01]],\n",
       "\n",
       "         [[-1.6479e-02, -2.6978e-02,  1.9775e-02,  ...,  3.1836e-01,\n",
       "            4.3945e-01,  3.3008e-01],\n",
       "          [-5.9375e-01, -1.4453e-01,  2.3047e-01,  ...,  2.9492e-01,\n",
       "            1.0449e-01, -6.1719e-01],\n",
       "          [ 2.3633e-01,  1.6602e-02, -6.5918e-02,  ...,  3.8672e-01,\n",
       "            4.5117e-01, -1.0156e+00],\n",
       "          ...,\n",
       "          [-1.8164e-01, -1.5625e-02,  3.2227e-02,  ...,  6.2891e-01,\n",
       "           -3.7695e-01, -7.6172e-01],\n",
       "          [-1.3984e+00,  4.7266e-01,  5.8984e-01,  ...,  3.3789e-01,\n",
       "           -5.3516e-01,  1.9141e-01],\n",
       "          [-1.3125e+00,  8.6328e-01,  1.2988e-01,  ..., -8.8672e-01,\n",
       "           -8.1641e-01, -2.6953e-01]],\n",
       "\n",
       "         [[-5.4016e-03, -1.9836e-03, -3.9978e-03,  ..., -3.1641e-01,\n",
       "            5.5176e-02,  1.9434e-01],\n",
       "          [-2.4512e-01, -4.9609e-01, -7.7734e-01,  ...,  3.7891e-01,\n",
       "            9.2188e-01, -2.7466e-02],\n",
       "          [ 3.6328e-01, -5.5078e-01,  3.5938e-01,  ...,  2.8906e-01,\n",
       "           -1.4160e-01,  6.3672e-01],\n",
       "          ...,\n",
       "          [-1.4648e-02,  2.6562e-01,  5.4688e-01,  ..., -1.4141e+00,\n",
       "           -1.2734e+00, -8.7109e-01],\n",
       "          [ 5.8203e-01,  2.9297e-02, -5.1953e-01,  ...,  8.7891e-01,\n",
       "            6.1035e-02, -7.6953e-01],\n",
       "          [-4.1504e-02,  5.4297e-01, -8.2031e-01,  ...,  9.2188e-01,\n",
       "           -9.9219e-01, -3.9258e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-3.5400e-02,  2.9907e-03, -4.9591e-04,  ...,  7.6172e-02,\n",
       "            2.6172e-01, -2.8809e-02],\n",
       "          [-1.1719e-02, -5.8203e-01,  8.0859e-01,  ..., -9.1406e-01,\n",
       "           -3.8574e-02, -2.2344e+00],\n",
       "          [-2.9492e-01, -9.7656e-01,  1.0547e+00,  ..., -2.0605e-01,\n",
       "           -3.5938e-01, -2.7969e+00],\n",
       "          ...,\n",
       "          [-6.4844e-01, -2.0898e-01, -1.4160e-01,  ..., -4.0234e-01,\n",
       "            9.4727e-02, -1.2188e+00],\n",
       "          [-4.3945e-01, -1.1250e+00, -2.0312e+00,  ..., -1.4609e+00,\n",
       "            1.5312e+00,  5.9766e-01],\n",
       "          [ 1.2402e-01,  8.6328e-01, -5.2344e-01,  ..., -4.0430e-01,\n",
       "            4.0820e-01, -1.3906e+00]],\n",
       "\n",
       "         [[-2.4658e-02, -9.0332e-03,  3.1128e-02,  ..., -1.9824e-01,\n",
       "            5.1953e-01,  2.8809e-02],\n",
       "          [ 1.0352e-01, -3.9062e-01, -5.3906e-01,  ...,  1.2656e+00,\n",
       "           -5.5078e-01, -2.5781e+00],\n",
       "          [ 3.7891e-01, -6.0938e-01, -6.2109e-01,  ...,  1.4453e+00,\n",
       "           -3.7500e-01, -1.1562e+00],\n",
       "          ...,\n",
       "          [-7.3853e-03,  3.4766e-01, -1.3867e-01,  ...,  7.6562e-01,\n",
       "           -6.6406e-01, -7.2656e-01],\n",
       "          [-1.8555e-02,  3.3594e-01, -6.2012e-02,  ...,  1.4844e+00,\n",
       "           -1.7500e+00,  4.0039e-01],\n",
       "          [-4.2969e-02,  4.0039e-01,  3.4766e-01,  ...,  1.7969e+00,\n",
       "           -1.5625e-01, -9.6875e-01]],\n",
       "\n",
       "         [[ 9.4604e-03, -4.1504e-03,  3.2959e-02,  ...,  3.4766e-01,\n",
       "            1.0469e+00,  4.3164e-01],\n",
       "          [-8.4961e-02, -3.7109e-02, -2.8076e-02,  ..., -1.2891e+00,\n",
       "           -1.9922e+00,  4.3164e-01],\n",
       "          [-1.6797e-01, -1.4771e-02, -1.3477e-01,  ..., -9.0234e-01,\n",
       "           -2.4531e+00, -9.1406e-01],\n",
       "          ...,\n",
       "          [-4.3750e-01,  3.7695e-01, -1.6797e-01,  ..., -3.2031e-01,\n",
       "           -1.8125e+00, -2.2656e+00],\n",
       "          [-1.5430e-01, -3.7305e-01, -1.9824e-01,  ..., -8.0078e-01,\n",
       "            2.9785e-02, -1.6953e+00],\n",
       "          [ 3.7891e-01, -1.5430e-01, -2.0703e-01,  ...,  6.6016e-01,\n",
       "            4.5508e-01, -7.6562e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-7.8125e-03, -4.6692e-03, -4.5013e-04,  ..., -6.9885e-03,\n",
       "           -8.5449e-03, -8.4839e-03],\n",
       "          [-3.3984e-01, -2.0020e-01,  2.7148e-01,  ...,  4.5898e-02,\n",
       "            8.7109e-01,  5.1562e-01],\n",
       "          [-1.8652e-01,  6.1279e-02,  2.6611e-02,  ...,  2.6562e-01,\n",
       "            1.4453e+00,  6.4844e-01],\n",
       "          ...,\n",
       "          [ 1.2305e-01, -6.4453e-01,  3.7305e-01,  ..., -1.1250e+00,\n",
       "            1.9043e-01, -9.5703e-01],\n",
       "          [-6.0547e-01, -4.5898e-01, -5.6641e-01,  ..., -1.5469e+00,\n",
       "           -3.9795e-02, -3.7891e-01],\n",
       "          [-8.2520e-02,  2.5757e-02,  2.6978e-02,  ..., -7.6172e-01,\n",
       "           -1.1035e-01,  1.9165e-02]],\n",
       "\n",
       "         [[ 1.8311e-02,  2.8992e-03,  1.8066e-02,  ..., -5.8594e-03,\n",
       "            4.4250e-03, -1.2894e-03],\n",
       "          [ 7.6953e-01, -8.1250e-01,  8.8867e-02,  ..., -1.0205e-01,\n",
       "           -3.1250e-01, -2.5781e-01],\n",
       "          [ 8.9062e-01, -1.3047e+00, -2.5977e-01,  ...,  2.4121e-01,\n",
       "           -5.2490e-02, -2.3730e-01],\n",
       "          ...,\n",
       "          [-5.1953e-01,  2.0020e-01, -2.2461e-01,  ..., -2.8906e-01,\n",
       "           -1.0986e-01, -5.9128e-04],\n",
       "          [-2.8906e-01,  5.2344e-01,  6.2988e-02,  ..., -4.4922e-01,\n",
       "           -2.9883e-01, -1.0400e-01],\n",
       "          [ 1.0205e-01, -6.4844e-01, -3.3203e-01,  ..., -5.7422e-01,\n",
       "           -4.4922e-01, -5.0781e-01]],\n",
       "\n",
       "         [[ 3.6001e-05, -1.6556e-03, -3.1128e-03,  ...,  2.0630e-02,\n",
       "           -1.0925e-02,  8.6670e-03],\n",
       "          [ 1.5039e-01, -7.6953e-01, -5.0781e-01,  ..., -3.6914e-01,\n",
       "            2.9883e-01,  6.0547e-01],\n",
       "          [ 4.7070e-01, -3.2422e-01, -9.8145e-02,  ...,  5.9766e-01,\n",
       "            4.6289e-01, -1.4453e-01],\n",
       "          ...,\n",
       "          [-6.4844e-01,  5.3516e-01,  9.1016e-01,  ..., -4.0039e-02,\n",
       "           -4.5166e-02, -8.3203e-01],\n",
       "          [-1.7395e-03, -6.2500e-01,  9.8438e-01,  ..., -5.4297e-01,\n",
       "           -3.1836e-01, -3.4570e-01],\n",
       "          [ 3.7500e-01, -1.7578e-01, -7.6172e-01,  ...,  3.3398e-01,\n",
       "            6.0156e-01, -6.4062e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.2451e-02, -6.0120e-03, -1.5640e-03,  ..., -5.1575e-03,\n",
       "            1.5015e-02,  6.8054e-03],\n",
       "          [ 1.3984e+00,  1.5991e-02,  2.5781e-01,  ..., -4.5654e-02,\n",
       "           -7.9590e-02, -1.0156e-01],\n",
       "          [ 1.4219e+00,  7.3438e-01,  1.3867e-01,  ...,  5.0391e-01,\n",
       "            1.6724e-02, -6.9922e-01],\n",
       "          ...,\n",
       "          [ 8.0859e-01,  4.3164e-01, -1.6699e-01,  ...,  9.3750e-01,\n",
       "            1.7773e-01, -1.2344e+00],\n",
       "          [ 8.3203e-01, -4.6875e-02, -5.1562e-01,  ...,  3.3594e-01,\n",
       "            1.5625e-01, -9.6484e-01],\n",
       "          [-5.1953e-01,  6.9580e-03, -1.3984e+00,  ..., -3.9844e-01,\n",
       "           -2.6953e-01, -1.3125e+00]],\n",
       "\n",
       "         [[ 7.0953e-04, -1.5747e-02, -4.2725e-03,  ..., -4.5776e-03,\n",
       "           -3.7689e-03, -2.4414e-03],\n",
       "          [-1.4160e-01, -5.3516e-01, -5.3711e-02,  ...,  5.7031e-01,\n",
       "            8.2812e-01,  4.4922e-01],\n",
       "          [-4.1016e-02, -7.1484e-01,  2.9297e-02,  ...,  1.4648e-01,\n",
       "            5.5859e-01,  1.0254e-01],\n",
       "          ...,\n",
       "          [-2.3730e-01, -2.5391e-01, -5.3906e-01,  ...,  2.8442e-02,\n",
       "           -4.3945e-01, -3.0078e-01],\n",
       "          [-4.9805e-01, -8.0469e-01,  6.3965e-02,  ...,  1.2500e+00,\n",
       "           -1.1816e-01,  1.1426e-01],\n",
       "          [ 2.8711e-01, -4.1406e-01,  4.6875e-02,  ..., -3.5156e-01,\n",
       "           -2.0801e-01,  6.3281e-01]],\n",
       "\n",
       "         [[-8.3008e-03,  1.7700e-02, -1.0742e-02,  ...,  1.6022e-03,\n",
       "           -1.6235e-02, -5.1575e-03],\n",
       "          [ 4.2969e-01,  1.8359e-01, -3.5938e-01,  ...,  6.9531e-01,\n",
       "           -1.5332e-01,  4.0234e-01],\n",
       "          [ 6.9141e-01, -1.8945e-01, -3.8867e-01,  ...,  9.2773e-02,\n",
       "            1.6895e-01,  4.5898e-02],\n",
       "          ...,\n",
       "          [ 4.2383e-01,  1.9434e-01, -3.1055e-01,  ...,  2.1680e-01,\n",
       "            1.3281e-01, -4.7607e-02],\n",
       "          [ 4.3164e-01, -6.9922e-01,  4.7266e-01,  ...,  4.3164e-01,\n",
       "           -3.5547e-01, -2.4780e-02],\n",
       "          [-5.6396e-02,  3.1445e-01,  2.7734e-01,  ...,  1.0193e-02,\n",
       "           -3.1641e-01,  2.2949e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-2.3556e-04,  1.9653e-02,  2.3682e-02,  ...,  1.9653e-02,\n",
       "           -6.5613e-03, -1.2158e-01],\n",
       "          [ 1.3516e+00, -3.3984e-01, -8.9062e-01,  ...,  2.5391e-01,\n",
       "            2.8906e-01,  9.6484e-01],\n",
       "          [ 6.5625e-01, -2.7930e-01, -1.2500e+00,  ...,  1.1182e-01,\n",
       "            3.7109e-01,  7.7344e-01],\n",
       "          ...,\n",
       "          [ 6.7188e-01, -3.7109e-01,  1.3906e+00,  ..., -4.3945e-01,\n",
       "           -4.3555e-01, -3.9062e-02],\n",
       "          [ 7.2656e-01, -1.0625e+00,  1.4609e+00,  ...,  1.0312e+00,\n",
       "           -7.1094e-01, -2.7539e-01],\n",
       "          [ 8.1641e-01,  1.2451e-01, -1.6211e-01,  ..., -1.1094e+00,\n",
       "            4.0625e-01, -1.1016e+00]],\n",
       "\n",
       "         [[ 2.7771e-03, -8.0566e-03,  2.1851e-02,  ..., -2.2461e-01,\n",
       "           -2.0605e-01, -6.5234e-01],\n",
       "          [-9.6875e-01, -1.0938e+00,  4.8828e-02,  ...,  1.1406e+00,\n",
       "           -1.7031e+00,  1.5391e+00],\n",
       "          [ 7.1484e-01, -1.0156e+00,  1.5259e-02,  ...,  9.7266e-01,\n",
       "           -9.3750e-01,  1.6562e+00],\n",
       "          ...,\n",
       "          [ 5.6250e-01, -3.2422e-01, -9.3750e-01,  ...,  1.0312e+00,\n",
       "           -5.8594e-01, -4.0820e-01],\n",
       "          [-3.5938e-01, -1.7285e-01, -2.6758e-01,  ...,  7.2656e-01,\n",
       "           -7.3438e-01, -9.3750e-01],\n",
       "          [-2.8711e-01,  5.3516e-01, -3.4375e-01,  ...,  9.8828e-01,\n",
       "           -1.0938e+00,  2.9053e-02]],\n",
       "\n",
       "         [[-1.1597e-02,  1.5488e-03,  3.5156e-02,  ..., -3.6328e-01,\n",
       "            4.4678e-02,  1.6113e-01],\n",
       "          [-2.6172e-01,  8.5938e-02, -2.2559e-01,  ..., -8.3203e-01,\n",
       "           -1.8203e+00, -4.2578e-01],\n",
       "          [-4.5508e-01,  9.0234e-01,  6.3477e-02,  ...,  3.9844e-01,\n",
       "           -2.0625e+00, -2.5391e-01],\n",
       "          ...,\n",
       "          [ 5.7031e-01,  9.8633e-02,  2.3828e-01,  ...,  1.7812e+00,\n",
       "           -1.2578e+00, -3.5547e-01],\n",
       "          [ 5.2734e-01,  8.8672e-01,  3.0859e-01,  ...,  1.2031e+00,\n",
       "           -1.5234e-01, -8.3203e-01],\n",
       "          [-4.2188e-01, -5.3516e-01, -1.0547e-01,  ...,  1.8125e+00,\n",
       "           -4.2773e-01,  3.6133e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.8921e-02,  2.1240e-02,  2.1362e-02,  ...,  1.9531e+00,\n",
       "            2.2461e-01, -1.0625e+00],\n",
       "          [-1.4551e-01, -2.9883e-01, -2.0703e-01,  ..., -3.8125e+00,\n",
       "            2.1094e+00,  1.7812e+00],\n",
       "          [-5.8984e-01, -6.7188e-01, -9.2188e-01,  ..., -4.1562e+00,\n",
       "            2.1719e+00,  1.3125e+00],\n",
       "          ...,\n",
       "          [ 4.6680e-01,  4.5703e-01,  3.8672e-01,  ..., -3.9688e+00,\n",
       "            7.9688e-01, -1.1094e+00],\n",
       "          [ 2.3340e-01, -1.6797e-01,  2.2559e-01,  ..., -6.8750e-01,\n",
       "            2.7500e+00,  8.3984e-01],\n",
       "          [ 2.7734e-01,  5.7031e-01,  5.0781e-01,  ..., -5.3750e+00,\n",
       "            2.2812e+00,  5.2246e-02]],\n",
       "\n",
       "         [[-2.7100e-02, -7.4158e-03, -2.2461e-02,  ..., -6.9141e-01,\n",
       "            9.0820e-02,  4.0039e-01],\n",
       "          [-1.3281e-01, -4.4531e-01,  2.8711e-01,  ...,  1.6016e+00,\n",
       "           -1.2734e+00, -8.6328e-01],\n",
       "          [-4.7266e-01, -3.1641e-01,  6.2109e-01,  ...,  2.1562e+00,\n",
       "           -3.7891e-01, -3.4766e-01],\n",
       "          ...,\n",
       "          [ 3.2617e-01, -8.7891e-01, -9.0234e-01,  ..., -3.9648e-01,\n",
       "           -2.0469e+00, -1.1797e+00],\n",
       "          [-2.3438e-01, -1.1719e-01, -1.8457e-01,  ..., -2.9883e-01,\n",
       "           -2.3438e+00, -7.3828e-01],\n",
       "          [-1.6406e-01, -3.9062e-02,  4.4922e-02,  ..., -6.9141e-01,\n",
       "           -7.2656e-01,  1.2578e+00]],\n",
       "\n",
       "         [[-2.4292e-02,  4.6997e-03, -7.7820e-03,  ...,  1.4160e-01,\n",
       "           -8.9844e-02,  7.4219e-02],\n",
       "          [-4.6484e-01,  4.0234e-01, -1.1406e+00,  ...,  4.6484e-01,\n",
       "            8.9844e-02,  1.5000e+00],\n",
       "          [ 6.4453e-02, -4.8242e-01, -8.5938e-01,  ...,  3.8477e-01,\n",
       "            5.5469e-01,  1.6875e+00],\n",
       "          ...,\n",
       "          [-3.5156e-01, -9.7656e-01, -1.1094e+00,  ...,  2.6953e-01,\n",
       "           -1.7344e+00,  1.5781e+00],\n",
       "          [-1.6641e+00, -6.8750e-01, -3.4375e-01,  ..., -9.2578e-01,\n",
       "           -1.0391e+00,  1.1875e+00],\n",
       "          [-8.8672e-01,  5.7422e-01,  1.1953e+00,  ...,  2.7930e-01,\n",
       "           -2.4902e-01,  4.9609e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-5.2185e-03,  2.2583e-02, -1.8692e-03,  ..., -1.6785e-03,\n",
       "           -9.8419e-04,  9.5825e-03],\n",
       "          [-5.4688e-01, -1.6846e-02,  6.8359e-02,  ..., -6.2500e-01,\n",
       "            1.4941e-01,  3.6523e-01],\n",
       "          [-6.5234e-01, -1.4258e-01,  5.9570e-02,  ..., -3.7891e-01,\n",
       "            2.9688e-01,  2.5195e-01],\n",
       "          ...,\n",
       "          [-1.8066e-01,  2.1484e-01, -4.1797e-01,  ..., -9.0234e-01,\n",
       "           -5.5859e-01,  2.8516e-01],\n",
       "          [-1.3379e-01, -1.2012e-01, -3.7500e-01,  ..., -8.6328e-01,\n",
       "           -9.8828e-01, -3.0078e-01],\n",
       "          [ 3.7305e-01, -2.1875e-01,  4.4678e-02,  ..., -1.2988e-01,\n",
       "           -5.8594e-01, -5.3125e-01]],\n",
       "\n",
       "         [[-2.7832e-02, -6.5613e-03, -5.6458e-03,  ...,  1.9897e-02,\n",
       "            9.7275e-04, -1.3580e-03],\n",
       "          [-2.5391e-01, -4.5898e-01, -2.1777e-01,  ..., -8.2422e-01,\n",
       "            1.8848e-01,  2.2754e-01],\n",
       "          [-4.0527e-02, -8.5938e-01, -2.0605e-01,  ..., -7.2266e-01,\n",
       "            2.9883e-01, -9.4727e-02],\n",
       "          ...,\n",
       "          [-2.4512e-01, -2.6367e-01,  2.9297e-01,  ...,  2.3633e-01,\n",
       "            1.0645e-01, -1.0010e-01],\n",
       "          [ 4.4531e-01,  2.1582e-01,  2.5977e-01,  ...,  4.1992e-01,\n",
       "            1.6211e-01, -4.5898e-01],\n",
       "          [-4.4922e-01, -2.7148e-01,  3.5547e-01,  ...,  1.6113e-01,\n",
       "            3.3789e-01,  8.5938e-01]],\n",
       "\n",
       "         [[-2.7008e-03,  1.0071e-02, -1.3580e-03,  ..., -1.4877e-04,\n",
       "           -2.0905e-03,  1.1597e-02],\n",
       "          [-8.7891e-03, -1.1328e+00, -3.2617e-01,  ..., -3.8086e-01,\n",
       "            5.6250e-01,  1.6113e-01],\n",
       "          [-6.8359e-02, -1.0312e+00, -3.8086e-01,  ...,  2.6562e-01,\n",
       "            4.9023e-01,  9.0820e-02],\n",
       "          ...,\n",
       "          [-3.2715e-02,  2.1094e-01, -1.1914e-01,  ...,  5.5859e-01,\n",
       "           -9.1406e-01,  3.9648e-01],\n",
       "          [ 6.3672e-01,  5.5859e-01, -1.5234e-01,  ..., -1.2329e-02,\n",
       "            4.8438e-01, -1.5039e-01],\n",
       "          [-4.8340e-02,  4.1992e-01, -5.3125e-01,  ..., -7.1777e-02,\n",
       "            1.6016e-01,  4.0625e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.2085e-02,  9.3994e-03, -8.6670e-03,  ..., -2.5146e-02,\n",
       "           -1.0071e-03,  5.2185e-03],\n",
       "          [ 3.2227e-01, -5.3516e-01, -5.6641e-01,  ...,  8.7109e-01,\n",
       "            5.5078e-01, -2.2852e-01],\n",
       "          [ 1.2598e-01,  3.2031e-01, -4.5410e-02,  ...,  8.6328e-01,\n",
       "            3.8672e-01,  2.4121e-01],\n",
       "          ...,\n",
       "          [ 4.1504e-02,  3.2227e-01,  1.1768e-01,  ...,  8.6719e-01,\n",
       "           -6.0547e-01, -1.7578e-02],\n",
       "          [ 1.2256e-01,  9.6484e-01, -7.8906e-01,  ...,  6.8359e-01,\n",
       "           -1.0889e-01,  6.4453e-01],\n",
       "          [ 3.5352e-01, -4.8828e-02, -2.4414e-01,  ...,  5.9766e-01,\n",
       "            3.0078e-01,  3.7305e-01]],\n",
       "\n",
       "         [[-1.4832e-02, -2.4109e-03,  3.0518e-03,  ..., -1.5717e-03,\n",
       "            6.5918e-03,  1.8433e-02],\n",
       "          [ 2.2363e-01, -1.2812e+00, -2.5586e-01,  ..., -4.8438e-01,\n",
       "            6.8750e-01, -5.2734e-01],\n",
       "          [ 3.7500e-01, -6.0156e-01,  3.1738e-02,  ..., -2.0508e-01,\n",
       "            6.8750e-01, -1.2207e-01],\n",
       "          ...,\n",
       "          [ 5.3906e-01,  5.0000e-01,  3.5547e-01,  ...,  4.4189e-02,\n",
       "           -5.8984e-01,  2.6562e-01],\n",
       "          [-1.5527e-01, -5.0781e-01, -7.4707e-02,  ..., -4.4531e-01,\n",
       "            9.1797e-02,  1.4355e-01],\n",
       "          [-1.9629e-01, -1.4160e-01, -1.0859e+00,  ..., -2.9883e-01,\n",
       "            4.6680e-01, -3.4570e-01]],\n",
       "\n",
       "         [[ 1.0437e-02, -2.5513e-02, -7.5378e-03,  ...,  6.6223e-03,\n",
       "           -1.9897e-02, -3.8605e-03],\n",
       "          [-2.1191e-01, -2.9297e-01,  4.7852e-01,  ...,  3.0664e-01,\n",
       "            8.9062e-01,  2.9492e-01],\n",
       "          [ 2.1875e-01, -1.5503e-02,  8.2812e-01,  ...,  8.6719e-01,\n",
       "            8.9453e-01, -3.3112e-03],\n",
       "          ...,\n",
       "          [ 1.0703e+00, -8.3008e-02, -5.8594e-01,  ..., -5.5078e-01,\n",
       "            5.5859e-01,  5.1953e-01],\n",
       "          [ 4.6875e-01, -5.6250e-01,  1.5625e+00,  ..., -7.0703e-01,\n",
       "            2.6367e-01,  2.0996e-01],\n",
       "          [ 8.2031e-01,  3.4961e-01,  7.1484e-01,  ...,  3.6914e-01,\n",
       "            6.4844e-01, -1.0391e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 3.3203e-02,  2.2430e-03, -5.4016e-03,  ...,  1.3672e-01,\n",
       "           -2.4902e-01, -1.2656e+00],\n",
       "          [-2.7734e-01,  5.9375e-01, -1.8457e-01,  ...,  7.3242e-02,\n",
       "            3.2969e+00,  8.0000e+00],\n",
       "          [-8.0859e-01,  3.9062e-01, -2.7344e-01,  ...,  7.7344e-01,\n",
       "            3.1562e+00,  7.5625e+00],\n",
       "          ...,\n",
       "          [ 8.5547e-01, -3.1641e-01,  3.8477e-01,  ...,  3.0469e-01,\n",
       "            3.9844e+00,  5.7500e+00],\n",
       "          [ 4.8828e-01, -1.2305e-01,  2.8125e-01,  ..., -7.6953e-01,\n",
       "            2.9688e+00,  4.5625e+00],\n",
       "          [ 2.7344e-01,  5.9570e-02, -1.0742e-01,  ..., -1.3984e+00,\n",
       "            2.9688e+00,  4.4688e+00]],\n",
       "\n",
       "         [[-2.6001e-02,  3.4424e-02,  6.8665e-03,  ..., -2.1582e-01,\n",
       "            4.3213e-02,  3.2617e-01],\n",
       "          [-1.8359e-01, -2.8711e-01,  6.7383e-02,  ...,  1.7676e-01,\n",
       "            4.6875e-01,  2.5781e+00],\n",
       "          [-6.8359e-01,  5.9766e-01,  1.8164e-01,  ...,  2.2656e-01,\n",
       "            2.3145e-01,  1.7422e+00],\n",
       "          ...,\n",
       "          [-6.5625e-01,  3.4961e-01,  1.2695e-01,  ...,  1.2500e+00,\n",
       "            2.8711e-01, -2.5781e-01],\n",
       "          [ 3.0396e-02, -4.8242e-01, -4.8828e-03,  ...,  2.3906e+00,\n",
       "           -2.0215e-01, -2.2266e-01],\n",
       "          [ 9.5312e-01, -2.4023e-01, -4.2480e-02,  ..., -1.8359e-01,\n",
       "            9.9609e-01,  1.6484e+00]],\n",
       "\n",
       "         [[ 7.2327e-03, -1.0529e-03,  1.2207e-02,  ..., -9.0790e-04,\n",
       "           -5.2490e-02, -8.6914e-02],\n",
       "          [ 3.6523e-01, -1.5859e+00,  1.6895e-01,  ..., -8.4229e-03,\n",
       "            1.4648e-01,  1.1953e+00],\n",
       "          [-1.2012e-01, -1.0312e+00,  6.7383e-02,  ..., -3.3984e-01,\n",
       "            2.6172e-01,  2.0625e+00],\n",
       "          ...,\n",
       "          [ 1.4609e+00, -2.3438e-01,  3.0469e-01,  ...,  4.7852e-01,\n",
       "            1.3359e+00,  9.6484e-01],\n",
       "          [ 4.2188e-01, -4.4531e-01, -1.0107e-01,  ...,  2.0625e+00,\n",
       "            1.0469e+00,  3.3398e-01],\n",
       "          [-8.9355e-02, -7.8125e-01, -8.2031e-01,  ...,  3.3594e-01,\n",
       "            9.8828e-01,  8.7891e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.8076e-02, -9.4223e-04,  1.1047e-02,  ..., -8.0078e-02,\n",
       "            1.3086e-01,  9.6191e-02],\n",
       "          [-6.6406e-01, -1.5000e+00,  3.2031e-01,  ..., -1.1250e+00,\n",
       "           -1.1797e+00,  6.7578e-01],\n",
       "          [ 4.5898e-01, -1.4297e+00,  9.6680e-02,  ..., -3.6914e-01,\n",
       "           -1.0625e+00,  9.0625e-01],\n",
       "          ...,\n",
       "          [ 5.4688e-01,  1.0625e+00, -8.0078e-01,  ..., -1.3047e+00,\n",
       "            1.5527e-01,  1.7285e-01],\n",
       "          [ 6.2500e-01,  5.1172e-01, -5.8203e-01,  ..., -7.5781e-01,\n",
       "            1.0547e+00, -6.7188e-01],\n",
       "          [-3.7109e-01, -6.8359e-01,  2.8906e-01,  ..., -1.9375e+00,\n",
       "           -2.0312e-01, -4.7852e-02]],\n",
       "\n",
       "         [[-5.8899e-03,  8.7891e-03, -1.0834e-03,  ...,  1.6309e-01,\n",
       "           -9.0332e-02, -1.0840e-01],\n",
       "          [-1.4531e+00,  9.8633e-02,  1.2354e-01,  ..., -1.4375e+00,\n",
       "            1.0938e+00, -2.4707e-01],\n",
       "          [-1.2891e+00, -4.6143e-02,  2.2754e-01,  ..., -1.4375e+00,\n",
       "            1.5859e+00,  5.2734e-01],\n",
       "          ...,\n",
       "          [ 1.1953e+00, -5.5078e-01,  1.3281e-01,  ..., -2.8711e-01,\n",
       "            6.0547e-01,  1.6406e+00],\n",
       "          [ 7.1484e-01, -1.2109e-01,  1.1641e+00,  ...,  2.2656e+00,\n",
       "           -1.0059e-01,  7.8613e-02],\n",
       "          [-1.0312e+00,  6.6406e-01,  8.1250e-01,  ...,  3.1445e-01,\n",
       "           -2.7222e-02,  1.1406e+00]],\n",
       "\n",
       "         [[ 1.9165e-02,  3.7354e-02, -2.7466e-02,  ..., -1.2024e-02,\n",
       "           -2.8711e-01,  2.0215e-01],\n",
       "          [ 3.9844e-01,  1.1328e+00,  2.7734e-01,  ...,  1.0254e-02,\n",
       "            1.5938e+00, -1.5469e+00],\n",
       "          [-2.3633e-01, -2.1289e-01, -2.6172e-01,  ...,  7.5781e-01,\n",
       "            1.1523e-01, -6.1328e-01],\n",
       "          ...,\n",
       "          [ 1.0889e-01, -1.9922e-01, -3.6719e-01,  ..., -4.0234e-01,\n",
       "           -5.1172e-01, -1.5000e+00],\n",
       "          [ 1.0078e+00, -4.7266e-01, -6.0938e-01,  ...,  8.9062e-01,\n",
       "           -5.7422e-01, -1.6641e+00],\n",
       "          [ 1.9434e-01,  7.4609e-01,  7.4219e-01,  ...,  1.2969e+00,\n",
       "           -9.2578e-01, -2.2969e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-1.0620e-02,  2.0752e-02,  5.3711e-03,  ..., -2.7100e-02,\n",
       "            9.7046e-03, -1.3733e-02],\n",
       "          [-4.2969e-02,  5.3516e-01,  2.0898e-01,  ..., -4.1797e-01,\n",
       "           -1.0547e+00,  2.5195e-01],\n",
       "          [ 4.7852e-01, -2.3633e-01, -2.5195e-01,  ..., -9.4238e-02,\n",
       "           -1.7383e-01,  3.6328e-01],\n",
       "          ...,\n",
       "          [-3.8477e-01, -4.1992e-01,  5.1953e-01,  ...,  1.1953e+00,\n",
       "           -5.9375e-01, -1.2146e-02],\n",
       "          [ 3.8477e-01,  5.0391e-01,  9.5312e-01,  ...,  4.1992e-01,\n",
       "           -1.0234e+00, -5.3516e-01],\n",
       "          [ 3.2422e-01, -2.1777e-01,  1.5625e-01,  ...,  6.0938e-01,\n",
       "           -4.7266e-01,  5.6641e-01]],\n",
       "\n",
       "         [[-1.6724e-02,  2.5787e-03,  1.7452e-04,  ...,  2.0874e-02,\n",
       "           -6.8359e-03,  7.7515e-03],\n",
       "          [ 4.7266e-01,  3.5742e-01, -1.7480e-01,  ...,  2.2754e-01,\n",
       "           -6.5234e-01, -4.2578e-01],\n",
       "          [ 5.5469e-01,  9.8047e-01, -2.8125e-01,  ..., -1.3281e-01,\n",
       "           -4.4141e-01, -1.0000e+00],\n",
       "          ...,\n",
       "          [ 1.6211e-01, -1.6992e-01,  1.6406e-01,  ..., -1.1875e+00,\n",
       "            6.6797e-01,  3.3398e-01],\n",
       "          [ 4.5898e-01, -5.3125e-01,  2.9688e-01,  ..., -9.9121e-02,\n",
       "           -2.7539e-01, -4.4141e-01],\n",
       "          [ 4.3945e-01, -1.7285e-01, -2.6172e-01,  ..., -6.7383e-02,\n",
       "            2.3633e-01,  4.9219e-01]],\n",
       "\n",
       "         [[ 2.0447e-03,  4.1199e-03,  5.5237e-03,  ...,  2.1851e-02,\n",
       "           -8.8501e-03,  1.2024e-02],\n",
       "          [ 1.6016e-01,  6.8750e-01,  3.4961e-01,  ..., -1.3379e-01,\n",
       "           -1.6211e-01,  8.7891e-01],\n",
       "          [ 6.0547e-01,  2.1484e-01,  4.8633e-01,  ..., -1.3672e-01,\n",
       "           -1.7285e-01, -1.5991e-02],\n",
       "          ...,\n",
       "          [ 1.3184e-01, -4.2188e-01, -4.9219e-01,  ..., -2.6953e-01,\n",
       "           -9.6680e-02, -1.1328e-01],\n",
       "          [ 4.4922e-01, -5.8984e-01,  6.3281e-01,  ...,  3.0859e-01,\n",
       "            6.2500e-01,  4.5654e-02],\n",
       "          [ 3.7891e-01, -5.9766e-01, -5.3125e-01,  ..., -7.1875e-01,\n",
       "            4.2578e-01, -9.2578e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 3.5889e-02,  1.7700e-02,  1.2451e-02,  ...,  1.8921e-02,\n",
       "           -1.2329e-02,  2.6367e-02],\n",
       "          [-2.4536e-02, -2.7344e-02,  2.9102e-01,  ...,  3.4912e-02,\n",
       "           -5.4297e-01, -1.8066e-02],\n",
       "          [ 5.1172e-01, -6.7969e-01, -1.4648e-01,  ...,  3.4180e-01,\n",
       "           -5.6641e-01,  1.8457e-01],\n",
       "          ...,\n",
       "          [ 5.4297e-01,  4.2578e-01, -6.4941e-02,  ...,  1.0391e+00,\n",
       "           -8.6719e-01,  4.1406e-01],\n",
       "          [-5.1953e-01, -4.9072e-02, -1.2422e+00,  ...,  1.2012e-01,\n",
       "           -1.3125e+00,  1.4219e+00],\n",
       "          [ 1.6309e-01,  2.2095e-02, -4.1406e-01,  ...,  4.4531e-01,\n",
       "           -1.0449e-01,  6.2891e-01]],\n",
       "\n",
       "         [[-3.5400e-02,  2.1240e-02, -1.9897e-02,  ..., -1.4099e-02,\n",
       "           -3.7354e-02,  1.5259e-02],\n",
       "          [-5.8203e-01,  1.6406e-01, -8.8281e-01,  ..., -5.7812e-01,\n",
       "            2.1094e-01, -6.6797e-01],\n",
       "          [-6.2109e-01,  4.0430e-01, -1.3516e+00,  ..., -5.0391e-01,\n",
       "            1.0010e-01, -1.0625e+00],\n",
       "          ...,\n",
       "          [ 2.6562e-01,  1.4844e-01,  2.7148e-01,  ...,  3.8672e-01,\n",
       "            6.9531e-01,  3.3789e-01],\n",
       "          [-3.0664e-01, -3.9844e-01,  1.1816e-01,  ...,  7.2656e-01,\n",
       "            1.3574e-01,  3.2617e-01],\n",
       "          [-1.0889e-01,  3.5938e-01,  4.1260e-02,  ...,  1.0596e-01,\n",
       "            5.3516e-01,  2.5586e-01]],\n",
       "\n",
       "         [[-7.8735e-03, -1.3046e-03, -1.8799e-02,  ...,  1.9226e-03,\n",
       "           -5.3711e-03, -1.1719e-02],\n",
       "          [-5.1758e-02,  7.6172e-01,  6.0938e-01,  ...,  5.7031e-01,\n",
       "           -1.6504e-01, -4.4531e-01],\n",
       "          [ 2.5781e-01, -3.3984e-01,  9.2578e-01,  ...,  1.4688e+00,\n",
       "            8.2031e-02, -2.7539e-01],\n",
       "          ...,\n",
       "          [ 1.0078e+00, -4.8828e-01,  5.0000e-01,  ...,  8.3984e-01,\n",
       "           -2.3438e-01, -9.1797e-02],\n",
       "          [ 1.0078e+00, -1.2656e+00,  8.1250e-01,  ...,  9.5703e-01,\n",
       "           -1.3984e+00, -3.0664e-01],\n",
       "          [ 4.7461e-01,  1.6797e+00,  3.8477e-01,  ...,  7.1777e-02,\n",
       "           -7.4609e-01, -1.5547e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-4.8523e-03, -6.0547e-02,  3.6163e-03,  ..., -6.6895e-02,\n",
       "            1.9043e-01, -1.1816e-01],\n",
       "          [ 1.3086e-01,  2.3633e-01,  7.9102e-02,  ..., -3.5742e-01,\n",
       "           -1.5527e-01, -1.2012e-01],\n",
       "          [-8.3984e-01,  9.7656e-01,  2.9492e-01,  ..., -1.1133e-01,\n",
       "            5.6641e-01, -2.1484e-01],\n",
       "          ...,\n",
       "          [ 1.4453e-01, -4.7070e-01, -4.3555e-01,  ..., -1.0693e-01,\n",
       "           -3.2422e-01,  2.2188e+00],\n",
       "          [-5.1758e-02,  2.2188e+00, -1.4844e+00,  ...,  1.1182e-01,\n",
       "            5.0391e-01,  5.3516e-01],\n",
       "          [-4.6484e-01,  6.4844e-01, -1.0938e+00,  ...,  3.3984e-01,\n",
       "            2.8320e-01,  3.3008e-01]],\n",
       "\n",
       "         [[-1.0729e-04, -2.3682e-02, -1.6113e-02,  ..., -1.7871e-01,\n",
       "            7.4707e-02, -2.4707e-01],\n",
       "          [ 9.9609e-02, -2.5781e-01, -4.1016e-02,  ..., -2.7539e-01,\n",
       "            3.3398e-01, -6.6833e-03],\n",
       "          [-8.3984e-02, -2.0410e-01, -7.4609e-01,  ...,  5.4688e-01,\n",
       "            8.0078e-01, -3.0273e-01],\n",
       "          ...,\n",
       "          [-7.0312e-01, -3.2812e-01, -4.2383e-01,  ...,  4.1211e-01,\n",
       "            8.5449e-02,  8.4766e-01],\n",
       "          [-9.5312e-01, -4.7266e-01,  2.9297e-03,  ...,  1.7188e-01,\n",
       "           -1.3672e+00,  4.7070e-01],\n",
       "          [-6.8750e-01, -1.5137e-01,  7.8516e-01,  ...,  1.1797e+00,\n",
       "            8.0078e-01,  1.8672e+00]],\n",
       "\n",
       "         [[-2.0294e-03,  1.8433e-02, -1.1414e-02,  ...,  7.2754e-02,\n",
       "            2.3730e-01, -8.1055e-02],\n",
       "          [ 4.8828e-01, -2.9297e-03,  3.1836e-01,  ...,  2.6406e+00,\n",
       "            7.8125e-01,  3.8672e-01],\n",
       "          [ 2.8516e-01,  7.9297e-01,  1.8652e-01,  ...,  1.8281e+00,\n",
       "            4.9072e-02,  1.6484e+00],\n",
       "          ...,\n",
       "          [-1.7578e+00, -5.7812e-01,  5.0781e-02,  ...,  1.0986e-01,\n",
       "            6.8359e-01, -1.8359e-01],\n",
       "          [-1.8555e-01,  1.5625e-01,  6.3965e-02,  ...,  7.0703e-01,\n",
       "            3.9844e+00, -1.3594e+00],\n",
       "          [-7.1484e-01, -2.5391e-01, -2.3926e-02,  ...,  6.0938e-01,\n",
       "            2.0000e+00, -1.4258e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.9775e-02, -2.0996e-02, -1.0559e-02,  ...,  1.3574e-01,\n",
       "            1.3086e-01, -6.2500e-01],\n",
       "          [-3.8281e-01,  1.6016e-01,  4.4141e-01,  ..., -1.2344e+00,\n",
       "           -2.0625e+00,  4.1016e-01],\n",
       "          [ 1.8848e-01,  2.8516e-01,  4.1992e-01,  ..., -1.0938e+00,\n",
       "           -2.2344e+00,  1.3516e+00],\n",
       "          ...,\n",
       "          [ 3.8477e-01, -2.5586e-01, -3.4180e-01,  ..., -8.6328e-01,\n",
       "           -4.2578e-01,  6.9531e-01],\n",
       "          [-1.3281e-01, -1.5625e-01,  1.6895e-01,  ..., -1.7578e+00,\n",
       "            5.2734e-01,  7.5391e-01],\n",
       "          [ 2.1680e-01, -2.1582e-01, -1.2500e-01,  ..., -1.1797e+00,\n",
       "           -7.8516e-01,  5.3125e-01]],\n",
       "\n",
       "         [[-2.0630e-02,  3.0029e-02, -3.5889e-02,  ...,  2.5938e+00,\n",
       "           -3.0469e-01, -4.8242e-01],\n",
       "          [ 1.6016e-01, -1.6797e+00,  1.5469e+00,  ..., -8.3125e+00,\n",
       "            9.5703e-01, -1.6328e+00],\n",
       "          [ 2.3438e+00, -1.7578e+00,  4.3359e-01,  ..., -9.8125e+00,\n",
       "           -8.0859e-01, -2.6562e+00],\n",
       "          ...,\n",
       "          [-1.1797e+00,  1.0312e+00, -1.0625e+00,  ..., -1.0250e+01,\n",
       "           -9.8047e-01, -3.8906e+00],\n",
       "          [-7.2656e-01,  1.7578e+00,  5.3906e-01,  ..., -9.5625e+00,\n",
       "            1.4062e+00, -3.0938e+00],\n",
       "          [-4.8828e-01,  7.8906e-01, -4.6094e-01,  ..., -9.5000e+00,\n",
       "           -4.1016e-01, -5.0000e-01]],\n",
       "\n",
       "         [[-2.9144e-03, -1.6235e-02,  3.0975e-03,  ..., -1.9141e+00,\n",
       "           -2.0386e-02,  4.7852e-01],\n",
       "          [-3.3398e-01, -1.9238e-01,  1.1914e-01,  ...,  9.3125e+00,\n",
       "           -5.3125e-01,  1.4844e-01],\n",
       "          [-1.3672e-02,  5.7031e-01,  3.9648e-01,  ...,  1.0250e+01,\n",
       "           -5.1562e-01, -6.3672e-01],\n",
       "          ...,\n",
       "          [ 3.7109e-01,  1.4297e+00,  7.1094e-01,  ...,  1.0312e+01,\n",
       "            1.0596e-01, -2.0156e+00],\n",
       "          [-8.6914e-02, -4.2969e-01,  2.9297e-01,  ...,  1.0938e+01,\n",
       "           -1.1816e-01, -1.1641e+00],\n",
       "          [ 5.6641e-02,  3.4180e-02, -4.9219e-01,  ...,  1.0500e+01,\n",
       "           -1.3516e+00, -2.1875e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-2.0874e-02, -8.1787e-03,  3.3203e-02,  ..., -3.9795e-02,\n",
       "           -1.3245e-02,  3.2471e-02],\n",
       "          [-2.2656e-01,  3.6133e-02,  4.2578e-01,  ...,  4.3945e-01,\n",
       "            2.9102e-01,  2.4219e-01],\n",
       "          [-5.1562e-01,  1.3672e-01, -4.3164e-01,  ...,  1.0498e-01,\n",
       "           -4.0625e-01,  5.8203e-01],\n",
       "          ...,\n",
       "          [-2.8516e-01,  7.2754e-02, -8.1641e-01,  ...,  6.4941e-02,\n",
       "            1.0469e+00, -9.4922e-01],\n",
       "          [ 3.3398e-01,  1.8457e-01, -7.8516e-01,  ...,  3.0859e-01,\n",
       "           -3.7109e-01, -8.2031e-01],\n",
       "          [ 1.3359e+00, -1.3184e-01, -9.1406e-01,  ..., -3.4570e-01,\n",
       "           -5.9375e-01, -1.8066e-01]],\n",
       "\n",
       "         [[-1.4404e-02,  1.5625e-02, -1.7944e-02,  ..., -1.0925e-02,\n",
       "           -4.1199e-03,  1.6602e-02],\n",
       "          [ 5.7812e-01, -8.2812e-01, -4.9805e-01,  ...,  6.9922e-01,\n",
       "           -6.1719e-01, -6.3281e-01],\n",
       "          [ 1.9531e-01, -1.0547e+00, -5.4688e-01,  ...,  5.6641e-01,\n",
       "           -8.5938e-01, -4.1602e-01],\n",
       "          ...,\n",
       "          [-1.1641e+00,  3.2471e-02,  1.0449e-01,  ...,  2.5391e-01,\n",
       "           -4.9805e-01,  9.8633e-02],\n",
       "          [-6.9922e-01, -2.6562e-01, -1.3047e+00,  ..., -7.9688e-01,\n",
       "           -1.2812e+00,  3.4570e-01],\n",
       "          [-1.0010e-02,  2.8320e-01, -6.2988e-02,  ...,  1.6211e-01,\n",
       "            2.7539e-01,  5.1562e-01]],\n",
       "\n",
       "         [[ 1.4709e-02,  1.8433e-02, -2.8931e-02,  ...,  1.2878e-02,\n",
       "           -5.0964e-03,  1.5030e-03],\n",
       "          [-8.6426e-02,  3.8574e-02, -1.5000e+00,  ..., -4.6875e-01,\n",
       "            2.7734e-01, -6.0547e-01],\n",
       "          [-1.1816e-01,  2.5977e-01, -1.2266e+00,  ..., -5.1562e-01,\n",
       "            5.3906e-01, -2.9688e-01],\n",
       "          ...,\n",
       "          [-6.6016e-01,  2.5586e-01,  1.3867e-01,  ..., -6.8359e-01,\n",
       "           -9.0625e-01,  5.8984e-01],\n",
       "          [ 1.8125e+00, -1.8750e+00, -2.7500e+00,  ...,  6.4453e-01,\n",
       "            1.0312e+00,  8.7891e-01],\n",
       "          [ 8.3594e-01, -3.5352e-01, -5.1562e-01,  ..., -2.9102e-01,\n",
       "            5.2344e-01, -5.1270e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.6367e-02,  1.7643e-04, -4.6692e-03,  ...,  4.0283e-02,\n",
       "           -2.6855e-03,  6.9275e-03],\n",
       "          [ 2.5391e-01, -5.6250e-01, -5.1562e-01,  ...,  5.3125e-01,\n",
       "            1.4893e-02,  2.6367e-01],\n",
       "          [ 1.6895e-01, -1.1719e-01, -5.5078e-01,  ...,  8.4375e-01,\n",
       "           -3.1836e-01,  1.1797e+00],\n",
       "          ...,\n",
       "          [ 4.8438e-01,  2.1094e-01,  3.4668e-02,  ..., -1.1250e+00,\n",
       "            5.0781e-01, -1.4160e-01],\n",
       "          [ 2.7539e-01, -1.0303e-01, -5.2344e-01,  ...,  3.9648e-01,\n",
       "            2.5586e-01,  2.7539e-01],\n",
       "          [ 4.3945e-01,  4.4556e-03,  2.1875e-01,  ..., -1.9824e-01,\n",
       "           -1.7871e-01, -3.7695e-01]],\n",
       "\n",
       "         [[-2.0508e-02,  1.2695e-02, -2.8198e-02,  ...,  2.8610e-04,\n",
       "           -5.9509e-03, -1.1719e-02],\n",
       "          [-5.0293e-02,  4.4727e-01, -4.5312e-01,  ..., -1.8188e-02,\n",
       "            4.2480e-02, -6.5613e-03],\n",
       "          [ 9.9609e-02,  7.3828e-01,  4.0234e-01,  ..., -2.6562e-01,\n",
       "            9.2163e-03, -1.8652e-01],\n",
       "          ...,\n",
       "          [ 8.0078e-01,  5.2344e-01, -7.6562e-01,  ..., -2.4121e-01,\n",
       "            5.1953e-01,  3.1641e-01],\n",
       "          [ 9.4238e-02,  3.3398e-01, -8.3203e-01,  ..., -2.5586e-01,\n",
       "            3.0078e-01, -6.9531e-01],\n",
       "          [ 2.8906e-01, -5.4688e-01, -2.5195e-01,  ..., -4.9414e-01,\n",
       "            1.4062e-01, -1.8359e-01]],\n",
       "\n",
       "         [[-6.0272e-04,  4.9438e-03, -2.0996e-02,  ...,  3.9062e-03,\n",
       "            1.2390e-02, -1.0193e-02],\n",
       "          [-3.7695e-01, -3.3789e-01,  5.9766e-01,  ..., -9.9609e-01,\n",
       "            2.6367e-01, -6.1035e-02],\n",
       "          [-5.4688e-01, -1.9238e-01,  7.5391e-01,  ..., -9.1797e-01,\n",
       "           -4.2773e-01, -1.6504e-01],\n",
       "          ...,\n",
       "          [-2.6758e-01,  4.9805e-01,  8.3496e-02,  ..., -6.5234e-01,\n",
       "           -4.6387e-02, -2.4219e-01],\n",
       "          [-1.9688e+00, -9.8145e-02, -9.6094e-01,  ...,  3.7305e-01,\n",
       "            3.7500e-01,  2.3560e-02],\n",
       "          [-5.3906e-01,  3.9307e-02, -1.0938e-01,  ...,  4.1504e-02,\n",
       "           -6.5625e-01,  1.5234e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-3.5248e-03, -1.5991e-02, -1.3184e-02,  ...,  7.7637e-02,\n",
       "           -8.5938e-02, -5.9814e-02],\n",
       "          [-8.3594e-01, -2.5391e-02, -6.5234e-01,  ...,  4.3701e-02,\n",
       "           -9.0625e-01, -1.4844e-01],\n",
       "          [-3.4766e-01, -3.1445e-01,  1.0703e+00,  ...,  1.3477e-01,\n",
       "           -2.9297e-01, -1.0859e+00],\n",
       "          ...,\n",
       "          [ 1.8828e+00, -3.1641e-01,  7.8906e-01,  ..., -3.8672e-01,\n",
       "            2.1875e+00, -1.6016e+00],\n",
       "          [-1.5332e-01,  1.3672e+00, -5.9766e-01,  ..., -1.2109e+00,\n",
       "            1.3906e+00,  7.7637e-02],\n",
       "          [-9.3750e-01,  3.5352e-01,  3.2031e-01,  ...,  8.3203e-01,\n",
       "            4.5312e-01, -9.5703e-01]],\n",
       "\n",
       "         [[ 8.1787e-03,  9.5825e-03,  4.1389e-04,  ..., -9.6436e-03,\n",
       "            2.2461e-01, -2.3340e-01],\n",
       "          [ 3.3594e-01, -1.1816e-01, -4.4531e-01,  ...,  1.7031e+00,\n",
       "           -2.3906e+00,  1.4531e+00],\n",
       "          [ 5.4297e-01, -5.8594e-01, -2.4805e-01,  ...,  2.1562e+00,\n",
       "           -1.8359e+00,  1.1875e+00],\n",
       "          ...,\n",
       "          [-1.8359e+00,  9.1797e-02, -7.0703e-01,  ...,  1.5625e+00,\n",
       "           -9.8438e-01,  4.5117e-01],\n",
       "          [-1.3184e-02, -4.9023e-01,  2.5391e-02,  ...,  1.2578e+00,\n",
       "            1.3281e+00,  1.2500e+00],\n",
       "          [ 7.1289e-02, -5.2734e-01, -3.3594e-01,  ...,  2.1562e+00,\n",
       "           -4.1016e-01,  6.4844e-01]],\n",
       "\n",
       "         [[-2.1820e-03, -6.9275e-03, -1.0193e-02,  ...,  5.5420e-02,\n",
       "            2.6758e-01, -8.1250e-01],\n",
       "          [-1.3203e+00,  4.4922e-02, -1.8066e-01,  ...,  2.4512e-01,\n",
       "           -1.9727e-01,  2.0469e+00],\n",
       "          [ 2.8320e-01, -1.6250e+00, -4.6289e-01,  ...,  6.8359e-01,\n",
       "           -8.0469e-01,  2.8750e+00],\n",
       "          ...,\n",
       "          [ 4.4141e-01,  1.5000e+00, -8.7109e-01,  ...,  1.7734e+00,\n",
       "           -1.5469e+00,  1.0703e+00],\n",
       "          [-1.3379e-01,  1.4766e+00, -1.3359e+00,  ...,  7.6953e-01,\n",
       "           -1.2109e+00,  2.4219e+00],\n",
       "          [-1.7656e+00,  1.0078e+00, -1.6406e-01,  ..., -1.7734e+00,\n",
       "           -1.2656e+00,  1.8672e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.9165e-02,  1.0742e-02, -1.8921e-02,  ...,  2.4512e-01,\n",
       "            1.0254e-01, -6.6406e-02],\n",
       "          [ 2.7344e-02, -4.7852e-01,  5.2734e-01,  ...,  2.2031e+00,\n",
       "           -2.5781e-01, -5.6641e-01],\n",
       "          [ 8.9844e-02, -3.4766e-01,  2.0312e-01,  ..., -3.2812e-01,\n",
       "           -7.8906e-01, -3.1250e-01],\n",
       "          ...,\n",
       "          [-1.1230e-02,  2.8125e-01, -3.4912e-02,  ..., -1.1797e+00,\n",
       "           -8.0469e-01,  2.3281e+00],\n",
       "          [ 8.6914e-02, -2.2852e-01,  6.8359e-02,  ...,  4.3945e-02,\n",
       "           -1.7656e+00, -2.3145e-01],\n",
       "          [-4.8828e-02, -9.0820e-02,  4.0430e-01,  ..., -8.9844e-02,\n",
       "           -1.1094e+00,  8.3984e-01]],\n",
       "\n",
       "         [[-1.0071e-02, -5.5847e-03, -1.8311e-02,  ...,  2.4048e-02,\n",
       "            2.9102e-01,  5.0537e-02],\n",
       "          [ 5.8984e-01, -1.7578e-02, -2.5000e-01,  ...,  5.8984e-01,\n",
       "           -5.1562e-01, -7.0703e-01],\n",
       "          [ 2.5391e-01, -5.3906e-01, -5.2490e-03,  ...,  1.0547e+00,\n",
       "           -2.7344e-01, -4.3555e-01],\n",
       "          ...,\n",
       "          [ 6.5625e-01,  3.2812e-01,  3.0469e-01,  ...,  5.4688e-01,\n",
       "           -2.4062e+00, -2.7734e-01],\n",
       "          [ 1.3281e+00, -6.7969e-01, -1.8164e-01,  ...,  1.2891e+00,\n",
       "           -1.4609e+00,  3.0859e-01],\n",
       "          [ 7.0312e-02, -1.1328e+00,  3.5156e-02,  ...,  2.6758e-01,\n",
       "           -1.4922e+00, -2.2888e-03]],\n",
       "\n",
       "         [[ 8.4229e-03, -9.2163e-03,  2.2461e-02,  ...,  1.4062e-01,\n",
       "            3.3203e-01, -1.8438e+00],\n",
       "          [ 8.8672e-01, -2.9688e-01, -1.4258e-01,  ...,  6.2109e-01,\n",
       "           -1.5859e+00,  6.8750e+00],\n",
       "          [-6.7383e-02,  2.4121e-01, -2.0312e-01,  ...,  1.0469e+00,\n",
       "           -1.3906e+00,  7.8750e+00],\n",
       "          ...,\n",
       "          [ 3.3984e-01, -1.1641e+00,  8.7109e-01,  ..., -4.3750e-01,\n",
       "           -1.4844e+00,  7.8750e+00],\n",
       "          [ 8.2422e-01, -8.4375e-01, -6.9531e-01,  ..., -2.8125e-01,\n",
       "           -1.5000e+00,  8.0625e+00],\n",
       "          [ 2.9688e-01, -6.1328e-01, -8.6328e-01,  ...,  1.8203e+00,\n",
       "           -1.4297e+00,  7.0938e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 9.3994e-03,  2.8931e-02,  1.8433e-02,  ..., -3.2715e-02,\n",
       "           -1.2131e-03, -8.4229e-03],\n",
       "          [-1.6699e-01, -5.2344e-01,  8.9062e-01,  ...,  6.3672e-01,\n",
       "           -1.2188e+00, -6.9141e-01],\n",
       "          [-1.8750e-01, -2.3535e-01,  3.9648e-01,  ...,  6.7969e-01,\n",
       "           -6.7578e-01, -8.3203e-01],\n",
       "          ...,\n",
       "          [ 3.6133e-01,  9.3750e-02, -4.7607e-03,  ...,  1.1215e-03,\n",
       "            1.4941e-01, -6.3281e-01],\n",
       "          [-6.3281e-01,  3.2812e-01, -5.2734e-02,  ..., -6.9531e-01,\n",
       "           -6.4844e-01, -6.0547e-01],\n",
       "          [ 2.2095e-02, -3.6914e-01,  4.9414e-01,  ..., -7.7344e-01,\n",
       "           -1.7480e-01,  9.2773e-03]],\n",
       "\n",
       "         [[ 6.1340e-03, -1.6556e-03, -2.5940e-03,  ..., -1.4526e-02,\n",
       "            4.7302e-03,  5.1270e-03],\n",
       "          [ 6.7188e-01,  1.1719e+00,  1.7676e-01,  ..., -4.3164e-01,\n",
       "           -6.2891e-01,  9.4531e-01],\n",
       "          [ 1.3672e+00,  6.5234e-01,  2.6953e-01,  ..., -5.1562e-01,\n",
       "           -2.8320e-01,  9.8828e-01],\n",
       "          ...,\n",
       "          [ 3.6914e-01, -3.5938e-01,  4.5703e-01,  ..., -1.9453e+00,\n",
       "           -1.0107e-01,  1.3574e-01],\n",
       "          [ 3.7695e-01,  1.1133e-01, -5.5664e-02,  ..., -7.4219e-01,\n",
       "            5.5859e-01,  9.8047e-01],\n",
       "          [ 9.4141e-01,  2.5781e-01,  6.1523e-02,  ..., -5.1562e-01,\n",
       "           -2.9419e-02, -1.4062e-01]],\n",
       "\n",
       "         [[ 1.1963e-02,  1.5991e-02,  1.8066e-02,  ..., -4.7913e-03,\n",
       "            1.3367e-02, -2.5024e-02],\n",
       "          [-9.9609e-02,  1.0469e+00,  2.9883e-01,  ..., -4.1211e-01,\n",
       "            4.3750e-01,  3.5156e-01],\n",
       "          [-8.3984e-01,  1.1484e+00,  1.9336e-01,  ..., -3.8281e-01,\n",
       "            1.5723e-01,  1.3359e+00],\n",
       "          ...,\n",
       "          [ 6.6797e-01,  1.5234e+00,  7.7734e-01,  ..., -8.4766e-01,\n",
       "            1.1719e+00,  3.4766e-01],\n",
       "          [ 8.2812e-01,  5.5078e-01, -1.3477e-01,  ..., -1.4141e+00,\n",
       "            5.8203e-01,  5.4688e-01],\n",
       "          [ 5.8105e-02,  2.3242e-01,  7.3828e-01,  ..., -1.1172e+00,\n",
       "           -1.4453e-01, -1.1641e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.0742e-02,  3.9062e-03,  2.1577e-05,  ..., -4.4250e-03,\n",
       "           -8.7280e-03, -8.4229e-03],\n",
       "          [ 4.9023e-01, -2.6758e-01,  7.9297e-01,  ..., -2.5586e-01,\n",
       "           -2.0703e-01,  3.9453e-01],\n",
       "          [-9.5215e-02,  1.5991e-02,  8.8672e-01,  ...,  1.8750e-01,\n",
       "           -1.9043e-02,  7.2266e-01],\n",
       "          ...,\n",
       "          [ 1.4609e+00, -3.7305e-01, -6.8848e-02,  ...,  1.4062e-01,\n",
       "            8.7891e-02,  3.7109e-02],\n",
       "          [-9.1797e-01, -5.2246e-02, -5.4688e-01,  ...,  1.6328e+00,\n",
       "            1.0312e+00,  1.4160e-01],\n",
       "          [-8.4766e-01,  1.0840e-01,  8.0078e-01,  ...,  1.0205e-01,\n",
       "           -6.8359e-01, -9.2188e-01]],\n",
       "\n",
       "         [[-3.1494e-02,  4.3335e-03, -2.3438e-02,  ..., -2.5024e-02,\n",
       "            5.3711e-03, -5.5420e-02],\n",
       "          [ 1.2354e-01, -1.4844e-01,  1.0234e+00,  ...,  2.6758e-01,\n",
       "            8.9844e-01, -1.2578e+00],\n",
       "          [ 5.3906e-01, -5.7031e-01,  7.8906e-01,  ...,  3.3594e-01,\n",
       "            3.8086e-01, -1.1797e+00],\n",
       "          ...,\n",
       "          [ 5.5859e-01, -1.2969e+00,  2.1680e-01,  ...,  2.4023e-01,\n",
       "           -2.4048e-02,  3.4961e-01],\n",
       "          [-3.0273e-01, -5.7812e-01,  1.0234e+00,  ..., -7.5391e-01,\n",
       "           -1.5391e+00, -6.8750e-01],\n",
       "          [ 2.3535e-01, -1.4453e+00,  9.9609e-01,  ...,  8.5938e-01,\n",
       "            1.4941e-01,  3.5156e-01]],\n",
       "\n",
       "         [[-2.1057e-03, -9.8267e-03,  2.8687e-02,  ...,  7.5340e-05,\n",
       "           -3.4668e-02,  1.8311e-02],\n",
       "          [-2.5781e-01, -1.5918e-01, -1.4844e-01,  ..., -5.3516e-01,\n",
       "           -5.2344e-01, -3.4961e-01],\n",
       "          [ 3.0078e-01, -8.2520e-02, -4.0039e-01,  ..., -1.3281e-01,\n",
       "           -1.1641e+00, -3.8818e-02],\n",
       "          ...,\n",
       "          [ 1.4609e+00, -1.4609e+00,  2.2827e-02,  ...,  3.5938e-01,\n",
       "            1.4258e-01,  1.5547e+00],\n",
       "          [ 7.8125e-01, -4.1016e-01,  5.1514e-02,  ..., -2.8320e-01,\n",
       "            1.4219e+00,  9.3750e-01],\n",
       "          [ 1.1406e+00, -1.6211e-01, -8.2031e-01,  ..., -5.1172e-01,\n",
       "           -1.0547e+00,  1.1797e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 1.2207e-02,  1.3916e-02,  3.8757e-03,  ...,  1.1328e-01,\n",
       "            4.5312e-01, -1.9238e-01],\n",
       "          [-5.4688e-01,  8.5156e-01,  6.0156e-01,  ..., -2.2888e-03,\n",
       "           -1.1953e+00, -1.0986e-01],\n",
       "          [ 2.1680e-01,  2.8711e-01,  4.8633e-01,  ...,  4.8633e-01,\n",
       "            3.0884e-02,  7.5781e-01],\n",
       "          ...,\n",
       "          [-8.5547e-01, -3.7109e-01, -7.5781e-01,  ..., -1.2988e-01,\n",
       "           -1.7344e+00,  9.4531e-01],\n",
       "          [-7.8125e-02, -6.5430e-02, -2.0410e-01,  ...,  6.4844e-01,\n",
       "           -1.1094e+00,  8.9453e-01],\n",
       "          [ 4.2188e-01, -4.7119e-02, -2.4414e-02,  ..., -3.9062e-01,\n",
       "           -2.2812e+00,  5.1953e-01]],\n",
       "\n",
       "         [[-2.2583e-02,  5.2795e-03,  5.6152e-03,  ..., -9.7656e-02,\n",
       "            7.4707e-02, -9.0332e-02],\n",
       "          [ 1.5156e+00,  4.4531e-01,  5.7812e-01,  ..., -9.7656e-01,\n",
       "            1.0547e+00, -1.0781e+00],\n",
       "          [ 7.9688e-01,  2.9102e-01,  3.9648e-01,  ..., -5.1953e-01,\n",
       "            9.4922e-01, -1.4062e+00],\n",
       "          ...,\n",
       "          [-1.2031e+00,  4.7461e-01,  3.1641e-01,  ..., -1.3047e+00,\n",
       "            6.4844e-01, -1.0312e+00],\n",
       "          [ 3.3594e-01, -3.7305e-01,  1.0938e-01,  ...,  8.2422e-01,\n",
       "           -1.7500e+00, -1.9219e+00],\n",
       "          [ 7.8125e-01,  2.4023e-01, -1.4160e-02,  ..., -2.6406e+00,\n",
       "            8.4473e-02, -7.2656e-01]],\n",
       "\n",
       "         [[-2.9602e-03, -1.5564e-02,  3.3951e-04,  ...,  4.2578e-01,\n",
       "            2.7148e-01, -1.9336e-01],\n",
       "          [-1.0000e+00,  2.3438e-01,  5.7812e-01,  ...,  1.0449e-01,\n",
       "           -6.1328e-01,  1.3438e+00],\n",
       "          [-4.6094e-01,  7.5391e-01, -4.2773e-01,  ..., -1.0078e+00,\n",
       "           -1.6016e+00,  7.1875e-01],\n",
       "          ...,\n",
       "          [-4.1211e-01, -6.5625e-01,  1.4941e-01,  ...,  4.6875e-01,\n",
       "           -1.6602e-02,  1.0254e-01],\n",
       "          [ 1.3516e+00,  4.1406e-01, -1.1562e+00,  ..., -1.5039e-01,\n",
       "           -1.6406e+00,  8.7891e-01],\n",
       "          [-6.2012e-02, -2.0312e-01, -7.3242e-02,  ..., -5.1953e-01,\n",
       "           -1.2969e+00, -1.6699e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.2512e-02,  1.8066e-02, -5.9204e-03,  ..., -1.5312e+00,\n",
       "            1.8359e-01, -5.0000e-01],\n",
       "          [-1.0000e+00, -4.1992e-01, -2.3438e-01,  ...,  3.5469e+00,\n",
       "           -1.7344e+00, -2.5781e-01],\n",
       "          [-1.2031e+00, -2.7344e-01, -9.9609e-01,  ...,  4.6250e+00,\n",
       "           -1.1172e+00,  1.9922e-01],\n",
       "          ...,\n",
       "          [ 5.9375e-01,  7.9297e-01,  4.4727e-01,  ...,  4.0000e+00,\n",
       "           -9.2578e-01, -1.5469e+00],\n",
       "          [ 1.2422e+00, -2.9102e-01,  2.8320e-01,  ...,  4.7188e+00,\n",
       "           -7.0312e-01, -1.7344e+00],\n",
       "          [ 6.4844e-01,  6.0156e-01,  1.0234e+00,  ...,  4.1562e+00,\n",
       "           -1.3438e+00, -1.8125e+00]],\n",
       "\n",
       "         [[ 1.6602e-02, -1.5625e-02, -1.5015e-02,  ..., -1.7773e-01,\n",
       "            6.1035e-02, -2.1973e-03],\n",
       "          [ 2.9297e-01, -4.8047e-01, -3.9062e-03,  ..., -2.4121e-01,\n",
       "           -1.4766e+00,  1.5938e+00],\n",
       "          [ 5.7812e-01,  1.6895e-01, -4.5312e-01,  ...,  3.0859e-01,\n",
       "           -7.5391e-01,  1.8828e+00],\n",
       "          ...,\n",
       "          [-1.5938e+00,  4.3945e-01, -3.1250e-02,  ...,  6.9922e-01,\n",
       "           -2.8125e-01,  1.0469e+00],\n",
       "          [-2.9062e+00, -1.3750e+00, -4.5898e-01,  ...,  9.6875e-01,\n",
       "            9.5312e-01,  1.0859e+00],\n",
       "          [-9.6484e-01, -1.4844e-01, -4.7656e-01,  ..., -5.0391e-01,\n",
       "           -1.7969e-01, -4.3750e-01]],\n",
       "\n",
       "         [[-2.5269e-02,  1.8799e-02,  8.1062e-05,  ..., -3.0273e-01,\n",
       "           -8.3984e-02, -1.5137e-01],\n",
       "          [ 1.8438e+00, -1.6406e+00,  8.9844e-01,  ...,  8.2031e-01,\n",
       "            4.2236e-02, -1.5137e-02],\n",
       "          [ 9.2188e-01, -1.7266e+00,  4.1016e-01,  ...,  1.4609e+00,\n",
       "            6.7188e-01, -4.5703e-01],\n",
       "          ...,\n",
       "          [-1.4531e+00,  2.4219e+00, -1.5234e+00,  ...,  4.8584e-02,\n",
       "            7.3828e-01,  3.7891e-01],\n",
       "          [ 1.3047e+00,  2.1250e+00, -3.7891e-01,  ...,  1.9766e+00,\n",
       "           -3.8672e-01,  6.9922e-01],\n",
       "          [ 1.3125e+00, -2.4414e-01,  4.7266e-01,  ...,  2.7734e-01,\n",
       "            2.3828e-01,  6.3965e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 1.4954e-02, -6.3171e-03, -9.4604e-03,  ...,  7.8735e-03,\n",
       "            1.0803e-02, -1.9775e-02],\n",
       "          [ 5.4443e-02, -5.2734e-01,  7.3438e-01,  ..., -9.1406e-01,\n",
       "           -9.1406e-01,  1.3594e+00],\n",
       "          [-5.9375e-01, -4.1406e-01,  5.3906e-01,  ..., -1.7773e-01,\n",
       "           -1.2188e+00,  1.3359e+00],\n",
       "          ...,\n",
       "          [-1.3672e+00, -1.3594e+00,  6.6016e-01,  ..., -2.7734e-01,\n",
       "            9.3750e-01,  6.0425e-03],\n",
       "          [-9.4531e-01, -7.8613e-02, -7.9297e-01,  ...,  1.0781e+00,\n",
       "            7.5391e-01, -5.3906e-01],\n",
       "          [-1.1797e+00, -5.7422e-01,  6.4062e-01,  ...,  2.2852e-01,\n",
       "           -2.4170e-02, -2.6172e-01]],\n",
       "\n",
       "         [[-4.8218e-03, -2.4872e-03,  2.4780e-02,  ...,  6.7444e-03,\n",
       "            5.4932e-03,  1.1169e-02],\n",
       "          [-9.9121e-02,  1.1621e-01, -1.6309e-01,  ...,  4.2383e-01,\n",
       "           -1.0547e+00, -2.6953e-01],\n",
       "          [-4.9316e-02,  4.1797e-01, -8.7500e-01,  ...,  8.5156e-01,\n",
       "           -6.7871e-02,  1.2598e-01],\n",
       "          ...,\n",
       "          [-7.8125e-01,  1.6797e-01,  7.6953e-01,  ...,  1.3125e+00,\n",
       "           -6.5625e-01,  1.4160e-01],\n",
       "          [-7.9297e-01, -1.1250e+00,  8.5156e-01,  ..., -1.2891e-01,\n",
       "            3.5547e-01,  5.8203e-01],\n",
       "          [-1.3281e-01,  7.7734e-01, -1.4160e-02,  ...,  6.8848e-02,\n",
       "            2.0215e-01, -3.5352e-01]],\n",
       "\n",
       "         [[ 8.4839e-03,  1.1658e-02,  2.8687e-02,  ..., -2.3560e-02,\n",
       "            2.7344e-02,  3.8330e-02],\n",
       "          [ 1.1641e+00,  6.9531e-01,  1.2988e-01,  ...,  8.3203e-01,\n",
       "            9.7656e-01, -3.3203e-01],\n",
       "          [ 4.1992e-01,  9.0234e-01,  4.3945e-01,  ...,  1.2188e+00,\n",
       "            9.6094e-01, -7.6172e-01],\n",
       "          ...,\n",
       "          [ 3.1006e-02, -1.4941e-01, -1.4531e+00,  ...,  5.0000e-01,\n",
       "           -3.5742e-01, -8.5938e-02],\n",
       "          [ 3.1641e-01, -2.7148e-01, -1.6309e-01,  ...,  2.6562e-01,\n",
       "           -7.0703e-01, -6.5234e-01],\n",
       "          [ 2.0000e+00,  7.2266e-01,  2.7734e-01,  ...,  1.0703e+00,\n",
       "            2.8516e-01,  1.1797e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.3489e-02,  6.9580e-03, -1.1902e-02,  ..., -8.1787e-03,\n",
       "            9.9487e-03, -3.6316e-03],\n",
       "          [ 3.2031e-01, -3.3594e-01, -5.0781e-01,  ...,  3.0078e-01,\n",
       "           -1.2988e-01,  1.2734e+00],\n",
       "          [-5.9375e-01, -4.7852e-01, -1.0078e+00,  ...,  2.5757e-02,\n",
       "            1.8848e-01,  8.5156e-01],\n",
       "          ...,\n",
       "          [ 1.5918e-01,  3.1250e-01, -3.4668e-02,  ...,  2.1973e-01,\n",
       "            1.1484e+00, -9.0625e-01],\n",
       "          [-5.1172e-01,  4.1211e-01, -5.5469e-01,  ...,  2.1406e+00,\n",
       "            3.8867e-01, -1.8281e+00],\n",
       "          [ 8.0859e-01,  7.5391e-01,  4.8047e-01,  ...,  6.0156e-01,\n",
       "            3.6914e-01, -3.1250e-01]],\n",
       "\n",
       "         [[ 2.5513e-02,  1.8921e-02, -1.8066e-02,  ..., -5.9082e-02,\n",
       "            7.5195e-02, -2.4414e-02],\n",
       "          [-4.4727e-01, -4.3945e-01,  4.4336e-01,  ..., -1.2031e+00,\n",
       "           -1.1406e+00,  1.2188e+00],\n",
       "          [-1.1562e+00,  6.9531e-01,  2.3828e-01,  ..., -1.2969e+00,\n",
       "           -8.9062e-01,  8.7500e-01],\n",
       "          ...,\n",
       "          [ 5.7422e-01, -1.7090e-01,  6.0547e-02,  ..., -4.7070e-01,\n",
       "           -5.6641e-01,  6.7969e-01],\n",
       "          [ 5.9375e-01,  1.5527e-01, -6.7578e-01,  ...,  1.1250e+00,\n",
       "           -2.2031e+00,  7.3828e-01],\n",
       "          [-1.9897e-02, -7.5000e-01, -6.0156e-01,  ...,  2.7734e-01,\n",
       "           -1.1328e+00,  1.2969e+00]],\n",
       "\n",
       "         [[ 2.5757e-02, -6.3324e-04, -1.2634e-02,  ..., -7.3853e-03,\n",
       "           -1.4893e-02, -8.8501e-03],\n",
       "          [-4.3555e-01, -4.4434e-02, -2.2363e-01,  ..., -3.2031e-01,\n",
       "            7.8125e-01, -1.0547e+00],\n",
       "          [-2.5195e-01, -1.0449e-01, -8.8672e-01,  ..., -8.2812e-01,\n",
       "            9.6484e-01, -2.2852e-01],\n",
       "          ...,\n",
       "          [-7.9956e-03,  8.3984e-01,  2.8320e-01,  ..., -5.0000e-01,\n",
       "           -1.3516e+00, -3.4375e-01],\n",
       "          [-1.0312e+00, -2.3926e-01,  6.7578e-01,  ..., -3.8281e-01,\n",
       "           -5.4688e-01,  4.8828e-02],\n",
       "          [ 5.8203e-01,  4.0625e-01, -1.6895e-01,  ..., -1.5918e-01,\n",
       "           -1.4062e-01,  1.7773e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 2.1057e-03,  3.7109e-02, -1.8082e-03,  ...,  3.0664e-01,\n",
       "            6.5918e-02, -1.9434e-01],\n",
       "          [-8.1250e-01, -9.1016e-01, -1.2344e+00,  ..., -9.7656e-01,\n",
       "           -8.3984e-01,  3.1445e-01],\n",
       "          [-8.9844e-01, -8.7109e-01,  9.1309e-02,  ..., -5.4688e-01,\n",
       "           -2.7344e-01, -4.2773e-01],\n",
       "          ...,\n",
       "          [-3.2227e-02, -4.1602e-01,  1.0078e+00,  ...,  1.2207e-02,\n",
       "            8.9453e-01,  9.3750e-02],\n",
       "          [ 8.0078e-01, -1.0156e+00, -2.5586e-01,  ..., -1.3984e+00,\n",
       "            1.6172e+00, -5.8203e-01],\n",
       "          [-7.5000e-01,  3.6719e-01, -1.0234e+00,  ...,  2.9883e-01,\n",
       "            1.6328e+00,  1.6211e-01]],\n",
       "\n",
       "         [[-7.2098e-04,  1.3062e-02,  1.0254e-02,  ...,  6.2109e-01,\n",
       "            3.0273e-01,  5.2734e-01],\n",
       "          [ 1.5391e+00, -5.3906e-01, -5.3711e-02,  ..., -2.5156e+00,\n",
       "           -3.7695e-01, -2.7812e+00],\n",
       "          [ 6.6406e-01, -1.2207e-01,  5.1514e-02,  ..., -1.8516e+00,\n",
       "            6.0938e-01, -2.9219e+00],\n",
       "          ...,\n",
       "          [-1.2344e+00,  5.7812e-01,  9.7168e-02,  ..., -1.9844e+00,\n",
       "            4.4922e-01, -2.8750e+00],\n",
       "          [-4.6875e-01, -7.5000e-01,  9.5703e-01,  ..., -5.6250e-01,\n",
       "            1.0859e+00, -1.8516e+00],\n",
       "          [ 1.3984e+00,  4.7363e-02, -1.7383e-01,  ..., -2.6719e+00,\n",
       "           -1.1953e+00, -2.7812e+00]],\n",
       "\n",
       "         [[ 1.4954e-02,  1.6357e-02, -1.0605e-03,  ...,  4.6387e-02,\n",
       "            1.2656e+00, -1.2656e+00],\n",
       "          [-1.9922e-01,  2.4902e-02,  1.3281e-01,  ..., -1.1406e+00,\n",
       "           -1.6875e+00,  3.5469e+00],\n",
       "          [ 1.8750e-01, -1.1230e-01,  5.8594e-01,  ..., -6.4453e-02,\n",
       "           -1.4375e+00,  4.0938e+00],\n",
       "          ...,\n",
       "          [-1.0303e-01, -1.1719e-01, -1.3428e-02,  ..., -1.3438e+00,\n",
       "           -2.6719e+00,  3.0938e+00],\n",
       "          [ 1.0645e-01,  1.0547e-01,  2.2266e-01,  ..., -1.4062e+00,\n",
       "           -3.3750e+00,  2.5781e+00],\n",
       "          [ 2.6172e-01, -6.7383e-02, -9.0820e-02,  ...,  2.2266e-01,\n",
       "           -2.8281e+00,  4.8750e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-4.1504e-03,  3.3569e-03, -7.5378e-03,  ..., -1.6016e-01,\n",
       "           -2.5391e-01, -1.4551e-01],\n",
       "          [ 7.8125e-01,  3.2617e-01,  3.9844e-01,  ..., -8.3203e-01,\n",
       "            5.1953e-01, -4.6484e-01],\n",
       "          [-5.9375e-01, -3.5742e-01,  1.2598e-01,  ...,  3.1738e-02,\n",
       "           -9.4141e-01,  6.1719e-01],\n",
       "          ...,\n",
       "          [-1.6406e-01, -3.6719e-01,  8.5938e-01,  ..., -1.8066e-01,\n",
       "            1.1406e+00,  7.9590e-02],\n",
       "          [ 1.5156e+00,  5.0391e-01, -5.5469e-01,  ..., -7.2754e-02,\n",
       "            1.6562e+00,  8.5449e-02],\n",
       "          [ 2.1719e+00,  7.4219e-01, -3.3984e-01,  ...,  7.8906e-01,\n",
       "            4.6289e-01,  4.1016e-01]],\n",
       "\n",
       "         [[-1.1597e-02,  1.2329e-02, -2.7100e-02,  ...,  9.9219e-01,\n",
       "           -1.3965e-01,  4.0039e-01],\n",
       "          [ 5.6250e-01, -6.4453e-01,  5.8594e-01,  ..., -3.3906e+00,\n",
       "            1.9766e+00, -4.6387e-02],\n",
       "          [ 1.2402e-01,  2.3438e-01,  4.7070e-01,  ..., -3.9531e+00,\n",
       "            1.8906e+00, -3.4570e-01],\n",
       "          ...,\n",
       "          [-3.4668e-02,  3.9453e-01, -8.9844e-02,  ..., -2.9688e+00,\n",
       "            2.2812e+00, -8.4766e-01],\n",
       "          [-1.1035e-01,  1.1328e-01, -4.4336e-01,  ..., -1.3984e+00,\n",
       "            3.1250e-01,  4.5654e-02],\n",
       "          [-7.2656e-01, -1.2812e+00, -1.2988e-01,  ..., -3.0000e+00,\n",
       "            1.4062e+00,  1.9434e-01]],\n",
       "\n",
       "         [[ 7.2327e-03, -7.4463e-03,  1.3855e-02,  ...,  2.3242e-01,\n",
       "           -3.1250e-01, -7.9297e-01],\n",
       "          [ 4.4434e-02, -1.4746e-01, -1.2402e-01,  ..., -1.4844e-01,\n",
       "            1.3672e+00,  2.1406e+00],\n",
       "          [-1.5234e-01,  1.8164e-01, -7.8125e-01,  ..., -1.3125e+00,\n",
       "            1.3203e+00,  2.6719e+00],\n",
       "          ...,\n",
       "          [ 6.0303e-02,  2.8125e-01,  1.2695e-01,  ..., -8.8672e-01,\n",
       "            6.4062e-01,  3.9375e+00],\n",
       "          [-1.0938e-01,  4.7363e-02,  4.6484e-01,  ..., -3.0078e-01,\n",
       "            1.5312e+00,  3.6875e+00],\n",
       "          [ 3.3203e-01, -4.3750e-01, -4.2969e-01,  ..., -1.5781e+00,\n",
       "            1.9062e+00,  3.0156e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-7.8125e-03,  3.5553e-03, -1.5076e-02,  ..., -8.1177e-03,\n",
       "            2.0409e-04, -9.2773e-03],\n",
       "          [-8.1250e-01,  1.1953e+00,  5.0000e-01,  ...,  1.2344e+00,\n",
       "            1.2578e+00, -7.3047e-01],\n",
       "          [-4.6289e-01, -1.6406e-01,  4.6680e-01,  ...,  4.9023e-01,\n",
       "            1.0469e+00, -3.3398e-01],\n",
       "          ...,\n",
       "          [-2.0898e-01, -8.3984e-01, -3.1494e-02,  ...,  1.0742e-01,\n",
       "           -4.7656e-01,  7.0703e-01],\n",
       "          [ 2.1680e-01, -1.7773e-01,  9.3359e-01,  ..., -5.0781e-01,\n",
       "            8.1543e-02, -4.5410e-02],\n",
       "          [ 5.1953e-01,  3.6523e-01,  6.4062e-01,  ...,  1.5938e+00,\n",
       "            9.3359e-01, -1.1172e+00]],\n",
       "\n",
       "         [[ 6.5994e-04,  4.7913e-03,  3.0823e-03,  ...,  1.7014e-03,\n",
       "            4.5471e-03,  1.8677e-02],\n",
       "          [-5.2734e-02,  8.0078e-01,  3.3594e-01,  ..., -6.0730e-03,\n",
       "           -4.1211e-01, -3.4961e-01],\n",
       "          [-6.4453e-01,  5.5859e-01,  5.3516e-01,  ..., -1.4746e-01,\n",
       "           -7.4219e-01, -2.6562e-01],\n",
       "          ...,\n",
       "          [-6.5234e-01,  6.0791e-02, -4.8633e-01,  ..., -1.2598e-01,\n",
       "            6.2891e-01,  3.0078e-01],\n",
       "          [-1.4141e+00, -4.8047e-01, -1.4844e-01,  ...,  1.2578e+00,\n",
       "           -1.0107e-01, -3.9648e-01],\n",
       "          [-5.5469e-01,  1.2031e+00, -5.5078e-01,  ..., -3.7695e-01,\n",
       "           -1.8945e-01,  8.7500e-01]],\n",
       "\n",
       "         [[-2.6123e-02, -2.0142e-02, -3.1738e-02,  ...,  1.1536e-02,\n",
       "            3.5889e-02, -5.7129e-02],\n",
       "          [-1.0859e+00, -2.4609e-01, -2.2656e-01,  ...,  5.0391e-01,\n",
       "           -6.9531e-01,  3.1250e-01],\n",
       "          [-7.4219e-01, -7.5684e-02, -7.5781e-01,  ..., -2.1680e-01,\n",
       "           -1.2422e+00,  1.8555e-01],\n",
       "          ...,\n",
       "          [ 1.3770e-01,  6.5625e-01,  6.0547e-01,  ...,  7.0312e-01,\n",
       "            1.4062e-01, -2.2363e-01],\n",
       "          [-2.9883e-01,  9.7266e-01,  1.8848e-01,  ...,  9.8438e-01,\n",
       "            2.9883e-01, -6.7578e-01],\n",
       "          [-4.7852e-01,  2.8125e-01,  6.6406e-01,  ..., -6.8359e-02,\n",
       "            8.6426e-02,  1.2988e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.0742e-02,  2.9053e-02, -1.4404e-02,  ..., -3.6621e-02,\n",
       "           -5.3467e-02, -4.0527e-02],\n",
       "          [ 3.1250e-01,  6.7969e-01, -1.6016e-01,  ...,  8.2812e-01,\n",
       "            2.0801e-01,  6.4453e-02],\n",
       "          [ 7.3047e-01,  3.1445e-01, -5.5469e-01,  ...,  2.6367e-01,\n",
       "            4.6289e-01,  5.3125e-01],\n",
       "          ...,\n",
       "          [ 1.3750e+00, -7.2656e-01,  1.1719e+00,  ...,  1.1250e+00,\n",
       "            7.9688e-01,  5.7812e-01],\n",
       "          [ 7.5391e-01,  2.1191e-01,  5.9375e-01,  ..., -2.6562e-01,\n",
       "            1.0938e+00, -1.5938e+00],\n",
       "          [ 5.7422e-01, -3.5742e-01, -3.4375e-01,  ...,  1.0078e+00,\n",
       "            3.7891e-01,  2.0312e-01]],\n",
       "\n",
       "         [[-1.1749e-03,  1.4771e-02,  9.3994e-03,  ...,  1.0803e-02,\n",
       "           -1.5259e-02, -3.2715e-02],\n",
       "          [-4.2383e-01,  7.1484e-01,  5.6152e-02,  ..., -2.1851e-02,\n",
       "            8.0859e-01,  6.1035e-02],\n",
       "          [-9.2188e-01,  4.4922e-01, -5.1172e-01,  ...,  3.3594e-01,\n",
       "            6.7578e-01, -2.5977e-01],\n",
       "          ...,\n",
       "          [ 2.2656e-01, -6.7578e-01, -1.5259e-02,  ...,  2.9907e-02,\n",
       "           -1.8262e-01,  4.6289e-01],\n",
       "          [ 7.3438e-01,  1.1035e-01,  4.0430e-01,  ...,  6.7578e-01,\n",
       "           -1.2578e+00,  1.1016e+00],\n",
       "          [ 3.2617e-01,  2.7148e-01,  5.0391e-01,  ..., -4.4922e-01,\n",
       "            1.9434e-01,  3.1445e-01]],\n",
       "\n",
       "         [[-1.3794e-02, -2.6733e-02, -1.6357e-02,  ...,  1.1719e-02,\n",
       "            1.8768e-03, -1.2085e-02],\n",
       "          [-1.7773e-01,  3.4375e-01, -4.3701e-02,  ...,  1.4688e+00,\n",
       "           -2.8320e-01, -3.3398e-01],\n",
       "          [ 6.4844e-01,  1.7188e-01, -3.0859e-01,  ...,  1.2969e+00,\n",
       "            2.5195e-01,  2.0508e-01],\n",
       "          ...,\n",
       "          [ 1.0078e+00,  2.5781e-01,  7.1875e-01,  ...,  1.2422e+00,\n",
       "           -9.3262e-02,  7.4707e-02],\n",
       "          [ 1.0469e+00,  1.8945e-01,  7.4609e-01,  ...,  2.2852e-01,\n",
       "           -8.6719e-01,  2.6758e-01],\n",
       "          [ 1.1875e+00,  9.6484e-01,  4.1016e-01,  ...,  1.0000e+00,\n",
       "           -3.9258e-01, -8.2031e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 8.3008e-03, -1.3855e-02, -2.3071e-02,  ..., -2.5146e-02,\n",
       "           -1.2061e-01,  8.7280e-03],\n",
       "          [-1.3906e+00, -2.0020e-01,  5.3906e-01,  ...,  4.3750e-01,\n",
       "           -3.9844e-01, -4.4922e-02],\n",
       "          [-8.0078e-01,  4.1602e-01,  1.3281e-01,  ..., -5.0391e-01,\n",
       "           -2.0312e-01, -2.2969e+00],\n",
       "          ...,\n",
       "          [ 6.7969e-01,  6.0059e-02, -1.4453e-01,  ...,  9.8047e-01,\n",
       "            1.6328e+00, -2.9688e+00],\n",
       "          [ 1.0625e+00,  7.2656e-01, -7.2266e-01,  ...,  7.0312e-01,\n",
       "           -4.3555e-01, -1.6328e+00],\n",
       "          [-9.2969e-01,  6.0938e-01, -3.8477e-01,  ..., -1.7090e-01,\n",
       "            1.0156e+00,  1.3594e+00]],\n",
       "\n",
       "         [[-1.2436e-03, -4.3640e-03, -5.7983e-03,  ..., -2.9541e-02,\n",
       "            3.2715e-02, -2.6758e-01],\n",
       "          [ 1.7969e-01, -5.4297e-01,  4.9805e-01,  ..., -1.1875e+00,\n",
       "            1.7266e+00, -4.9414e-01],\n",
       "          [ 3.4766e-01,  1.9043e-02,  5.3125e-01,  ..., -2.5938e+00,\n",
       "            1.8203e+00,  1.1484e+00],\n",
       "          ...,\n",
       "          [ 3.5938e-01, -7.7148e-02,  1.4941e-01,  ...,  9.2578e-01,\n",
       "            7.7734e-01, -3.5469e+00],\n",
       "          [ 7.5391e-01, -2.5156e+00,  7.2266e-02,  ...,  9.9219e-01,\n",
       "           -1.1768e-01,  6.7969e-01],\n",
       "          [ 1.8457e-01, -4.7461e-01, -2.3828e-01,  ...,  1.0303e-01,\n",
       "            2.4414e-01,  3.2617e-01]],\n",
       "\n",
       "         [[ 1.1353e-02, -6.5002e-03,  1.3855e-02,  ..., -4.3213e-02,\n",
       "            3.0762e-02,  2.2266e-01],\n",
       "          [ 1.3750e+00, -2.5195e-01,  7.6953e-01,  ..., -1.0938e+00,\n",
       "           -3.9258e-01, -1.0391e+00],\n",
       "          [ 4.2383e-01,  7.6953e-01,  2.1094e-01,  ..., -9.9219e-01,\n",
       "           -4.3945e-01, -2.7734e-01],\n",
       "          ...,\n",
       "          [-1.4258e-01, -9.4141e-01, -1.7266e+00,  ..., -9.4531e-01,\n",
       "            1.7383e-01, -5.0781e-01],\n",
       "          [ 7.6562e-01, -1.2812e+00, -1.5234e+00,  ...,  5.7812e-01,\n",
       "           -1.1719e-01, -9.9609e-01],\n",
       "          [ 9.4922e-01, -6.7188e-01,  5.5469e-01,  ..., -1.1641e+00,\n",
       "           -9.8828e-01, -6.6406e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.9775e-02, -2.5146e-02, -1.7456e-02,  ..., -3.5742e-01,\n",
       "           -5.2734e-01,  6.0059e-02],\n",
       "          [-3.7500e-01, -6.7188e-01,  6.9922e-01,  ..., -2.0781e+00,\n",
       "           -6.8359e-01,  1.2656e+00],\n",
       "          [-3.0859e-01, -1.2578e+00,  9.8633e-02,  ..., -1.8203e+00,\n",
       "            2.5586e-01,  6.2109e-01],\n",
       "          ...,\n",
       "          [-8.6328e-01,  4.2773e-01, -2.3438e-02,  ...,  9.5703e-01,\n",
       "            2.6758e-01, -8.6719e-01],\n",
       "          [-4.7852e-01, -4.7656e-01,  1.1719e+00,  ...,  9.3359e-01,\n",
       "           -1.9453e+00, -2.3535e-01],\n",
       "          [ 4.4727e-01,  6.5234e-01, -1.6602e-02,  ...,  8.8867e-02,\n",
       "           -8.3203e-01, -1.9219e+00]],\n",
       "\n",
       "         [[ 1.5869e-02, -1.6235e-02, -1.5869e-02,  ..., -1.3184e-02,\n",
       "            1.8311e-02,  1.3574e-01],\n",
       "          [-2.9688e-01, -2.3438e-01, -1.1562e+00,  ...,  1.4297e+00,\n",
       "           -5.3125e-01, -1.1016e+00],\n",
       "          [-1.2891e-01, -1.1641e+00, -2.4219e-01,  ...,  1.4531e+00,\n",
       "           -1.2734e+00, -1.1816e-01],\n",
       "          ...,\n",
       "          [ 4.0234e-01, -2.2168e-01,  4.8047e-01,  ...,  2.6406e+00,\n",
       "           -5.2734e-02, -1.6328e+00],\n",
       "          [ 1.8848e-01, -1.4355e-01,  8.1641e-01,  ...,  2.4688e+00,\n",
       "            7.9102e-02,  6.7188e-01],\n",
       "          [ 7.6562e-01, -3.0078e-01,  2.7148e-01,  ...,  1.8359e+00,\n",
       "           -1.7031e+00, -1.1953e+00]],\n",
       "\n",
       "         [[ 2.4414e-03,  1.9775e-02, -1.6113e-02,  ...,  1.2695e-01,\n",
       "            1.9409e-02,  2.4414e-01],\n",
       "          [ 4.6289e-01, -1.6094e+00, -5.7812e-01,  ...,  5.8984e-01,\n",
       "           -1.9165e-02,  9.9609e-01],\n",
       "          [-2.6172e-01, -1.5137e-01, -5.8984e-01,  ...,  8.8672e-01,\n",
       "            1.1406e+00,  3.2031e-01],\n",
       "          ...,\n",
       "          [-4.9414e-01,  7.2656e-01, -7.0312e-01,  ..., -2.9883e-01,\n",
       "            1.0234e+00, -9.6484e-01],\n",
       "          [-1.7969e-01, -1.3906e+00,  1.9062e+00,  ..., -1.9375e+00,\n",
       "           -4.5898e-01, -4.9805e-01],\n",
       "          [-6.2891e-01, -6.4062e-01,  1.2734e+00,  ..., -3.9453e-01,\n",
       "            3.7500e-01, -2.3828e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-9.5215e-03,  3.0151e-02, -3.9551e-02,  ...,  2.5024e-02,\n",
       "            2.3842e-04,  1.3672e-02],\n",
       "          [ 3.3203e-01,  1.1562e+00,  7.5000e-01,  ..., -9.6484e-01,\n",
       "            7.1289e-02, -1.0703e+00],\n",
       "          [ 6.9824e-02,  6.6406e-01,  4.5117e-01,  ..., -3.8477e-01,\n",
       "            4.5654e-02,  7.6172e-02],\n",
       "          ...,\n",
       "          [-8.0078e-01, -4.9414e-01,  5.0000e-01,  ..., -3.0469e-01,\n",
       "            8.8672e-01, -6.6797e-01],\n",
       "          [ 3.2812e-01,  2.7930e-01,  1.1484e+00,  ...,  7.3438e-01,\n",
       "            1.0938e+00,  4.2383e-01],\n",
       "          [-1.1719e+00,  6.5625e-01,  6.0938e-01,  ..., -1.6484e+00,\n",
       "            4.3164e-01, -7.5000e-01]],\n",
       "\n",
       "         [[-9.0942e-03,  5.9891e-04,  1.2756e-02,  ...,  1.6602e-02,\n",
       "           -2.3041e-03, -2.3438e-02],\n",
       "          [ 1.0234e+00, -8.0859e-01, -3.8086e-01,  ...,  3.5156e-01,\n",
       "           -1.6172e+00, -6.9531e-01],\n",
       "          [ 1.1641e+00, -1.1328e+00, -4.5898e-01,  ...,  6.0547e-01,\n",
       "           -1.4062e+00,  9.5215e-03],\n",
       "          ...,\n",
       "          [-2.1387e-01, -2.4902e-01,  1.1641e+00,  ...,  2.2559e-01,\n",
       "            1.0156e+00,  9.9219e-01],\n",
       "          [-1.0859e+00, -2.9492e-01,  8.9453e-01,  ...,  4.9805e-02,\n",
       "            1.1108e-02,  6.5625e-01],\n",
       "          [-4.7852e-01, -9.0234e-01, -6.2500e-01,  ...,  7.6294e-03,\n",
       "           -2.0605e-01,  3.2617e-01]],\n",
       "\n",
       "         [[-2.5482e-03,  1.5991e-02, -4.5410e-02,  ..., -2.0508e-02,\n",
       "            1.0803e-02,  2.0874e-02],\n",
       "          [-1.1250e+00,  1.1084e-01, -6.9531e-01,  ..., -6.9141e-01,\n",
       "           -2.7930e-01, -2.4023e-01],\n",
       "          [-5.3125e-01,  6.3281e-01, -4.9805e-01,  ..., -5.8984e-01,\n",
       "           -5.3516e-01, -3.0078e-01],\n",
       "          ...,\n",
       "          [ 9.8438e-01,  1.0000e+00, -1.9629e-01,  ..., -1.2422e+00,\n",
       "           -4.9023e-01, -8.7500e-01],\n",
       "          [ 7.4609e-01,  2.3242e-01, -1.0469e+00,  ..., -5.0781e-01,\n",
       "           -3.5156e-01,  2.6953e-01],\n",
       "          [ 9.4922e-01,  9.6680e-02, -7.5391e-01,  ...,  4.9219e-01,\n",
       "            6.1523e-02, -2.8516e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 9.9487e-03,  7.1335e-04,  1.9409e-02,  ..., -1.0925e-02,\n",
       "           -6.2866e-03,  1.7471e-03],\n",
       "          [ 5.8594e-01,  3.3789e-01,  3.6328e-01,  ...,  2.9297e-03,\n",
       "            3.2654e-03,  4.2578e-01],\n",
       "          [ 5.8203e-01,  3.0078e-01,  3.4766e-01,  ...,  1.2305e-01,\n",
       "            4.5703e-01,  1.1016e+00],\n",
       "          ...,\n",
       "          [-1.4221e-02,  5.3906e-01, -4.2188e-01,  ...,  7.2266e-01,\n",
       "           -1.0859e+00,  3.6719e-01],\n",
       "          [-3.0859e-01, -4.4727e-01, -1.2578e+00,  ...,  5.7617e-02,\n",
       "           -6.2109e-01,  6.9141e-01],\n",
       "          [-2.3828e-01, -4.9805e-01, -2.7344e-01,  ...,  4.2383e-01,\n",
       "            3.4375e-01,  7.2266e-01]],\n",
       "\n",
       "         [[-3.5858e-04,  1.6235e-02, -7.2937e-03,  ..., -1.2878e-02,\n",
       "            5.3711e-03,  1.0986e-02],\n",
       "          [-1.0156e+00, -2.7161e-03, -8.2812e-01,  ..., -4.1504e-02,\n",
       "            4.5117e-01,  4.1992e-01],\n",
       "          [-6.3281e-01, -8.8867e-02, -2.5586e-01,  ..., -1.0681e-04,\n",
       "           -5.2344e-01,  2.4121e-01],\n",
       "          ...,\n",
       "          [-7.7734e-01,  9.1406e-01, -1.5564e-02,  ..., -4.7070e-01,\n",
       "           -1.9434e-01, -5.2344e-01],\n",
       "          [-2.6172e-01,  5.1562e-01,  2.8125e-01,  ..., -4.1406e-01,\n",
       "            6.1719e-01,  3.2031e-01],\n",
       "          [ 5.4297e-01,  6.3281e-01,  2.1484e-01,  ..., -1.5234e+00,\n",
       "           -8.0566e-02, -1.1133e-01]],\n",
       "\n",
       "         [[ 1.3855e-02, -5.0659e-03,  1.2512e-02,  ..., -7.8125e-03,\n",
       "           -3.2471e-02, -1.4160e-02],\n",
       "          [ 2.0605e-01,  3.1445e-01, -1.5137e-01,  ..., -6.1035e-02,\n",
       "            8.2031e-01,  3.8477e-01],\n",
       "          [-1.7090e-01,  8.7402e-02,  1.7773e-01,  ...,  3.7891e-01,\n",
       "            1.9219e+00,  4.4922e-01],\n",
       "          ...,\n",
       "          [ 1.1406e+00,  9.1016e-01,  3.3008e-01,  ...,  1.5918e-01,\n",
       "           -1.0703e+00, -6.6016e-01],\n",
       "          [-1.9922e-01,  3.7109e-01, -4.1797e-01,  ...,  1.0156e+00,\n",
       "           -1.2969e+00, -1.0156e+00],\n",
       "          [ 1.1133e-01,  1.6875e+00,  6.9922e-01,  ..., -6.3281e-01,\n",
       "            2.5156e+00, -1.0078e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-1.0864e-02, -3.4180e-02,  4.5410e-02,  ..., -8.7109e-01,\n",
       "           -9.8633e-02,  3.9673e-03],\n",
       "          [ 1.4219e+00, -5.0781e-01, -4.9609e-01,  ...,  1.5625e+00,\n",
       "            9.3750e-01,  1.8750e-01],\n",
       "          [ 6.4844e-01, -2.7148e-01, -1.7734e+00,  ...,  1.9531e+00,\n",
       "            1.0469e+00,  4.5703e-01],\n",
       "          ...,\n",
       "          [-1.1328e-01, -3.7598e-02,  0.0000e+00,  ...,  1.2656e+00,\n",
       "            2.9419e-02, -3.8477e-01],\n",
       "          [-6.9141e-01, -6.1719e-01,  1.7969e-01,  ...,  1.4688e+00,\n",
       "           -7.3047e-01,  4.3945e-01],\n",
       "          [ 5.8594e-01,  1.8164e-01,  2.4219e-01,  ...,  1.7109e+00,\n",
       "           -1.1084e-01, -1.5430e-01]],\n",
       "\n",
       "         [[ 6.0303e-02, -1.0205e-01,  1.5747e-02,  ...,  5.9082e-02,\n",
       "           -3.1445e-01,  5.1953e-01],\n",
       "          [-9.1016e-01, -1.0352e-01,  1.3867e-01,  ..., -2.1094e+00,\n",
       "           -1.1250e+00,  2.2266e-01],\n",
       "          [-1.1094e+00,  4.1016e-01, -1.2207e-01,  ..., -2.7344e+00,\n",
       "           -7.8125e-01,  3.9648e-01],\n",
       "          ...,\n",
       "          [ 1.6953e+00,  5.5469e-01, -5.0781e-01,  ...,  4.2188e-01,\n",
       "           -9.5312e-01, -1.9766e+00],\n",
       "          [ 7.4707e-02,  1.3594e+00,  1.6016e-01,  ..., -2.1875e+00,\n",
       "            4.8633e-01, -2.2344e+00],\n",
       "          [-6.0938e-01, -2.1484e-02, -7.6172e-02,  ..., -1.4844e+00,\n",
       "           -1.9297e+00, -9.2969e-01]],\n",
       "\n",
       "         [[ 3.8818e-02,  1.8555e-02,  3.3691e-02,  ..., -8.4229e-03,\n",
       "            3.1055e-01, -1.6562e+00],\n",
       "          [-5.7422e-01, -1.1475e-01,  1.9989e-03,  ..., -1.3184e-01,\n",
       "           -1.2188e+00,  1.0703e+00],\n",
       "          [-3.6914e-01,  5.9082e-02,  4.1260e-02,  ..., -5.5469e-01,\n",
       "           -1.1172e+00,  1.1641e+00],\n",
       "          ...,\n",
       "          [ 5.6641e-01, -4.6875e-01,  6.2500e-01,  ..., -2.1406e+00,\n",
       "           -9.2773e-02,  1.2158e-01],\n",
       "          [-5.2344e-01, -4.6289e-01,  5.4297e-01,  ..., -1.2188e+00,\n",
       "           -3.1641e-01,  7.6172e-01],\n",
       "          [-8.5938e-01,  1.1816e-01, -2.6367e-01,  ..., -2.0156e+00,\n",
       "           -4.1602e-01,  5.7422e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 5.7129e-02, -3.1738e-02,  1.1230e-02,  ...,  8.2422e-01,\n",
       "           -7.3828e-01,  1.4844e+00],\n",
       "          [-5.5469e-01, -4.4141e-01,  2.9883e-01,  ...,  3.6133e-01,\n",
       "            2.2344e+00, -4.7656e-01],\n",
       "          [-4.8047e-01, -4.2578e-01, -1.7578e-01,  ..., -3.8867e-01,\n",
       "            2.0156e+00, -8.3594e-01],\n",
       "          ...,\n",
       "          [ 2.8125e-01, -3.9648e-01, -1.9727e-01,  ..., -3.9258e-01,\n",
       "            1.8672e+00, -1.0234e+00],\n",
       "          [ 4.6875e-02,  1.4609e+00,  9.9219e-01,  ...,  3.6621e-02,\n",
       "            1.5938e+00,  4.9805e-01],\n",
       "          [-8.9844e-01,  4.9414e-01,  6.0547e-01,  ..., -2.9883e-01,\n",
       "            1.7500e+00, -1.3750e+00]],\n",
       "\n",
       "         [[ 4.7363e-02,  3.4912e-02, -2.1606e-02,  ..., -4.3945e-02,\n",
       "           -5.8203e-01,  1.9141e-01],\n",
       "          [-2.4902e-01,  2.4219e-01, -2.0898e-01,  ...,  1.9336e-01,\n",
       "            4.6484e-01,  9.9609e-01],\n",
       "          [-6.6406e-01,  6.2988e-02,  7.0312e-02,  ...,  4.5312e-01,\n",
       "           -2.7710e-02,  1.0703e+00],\n",
       "          ...,\n",
       "          [ 5.7422e-01,  6.0938e-01, -5.1562e-01,  ...,  9.3750e-01,\n",
       "           -9.2188e-01, -7.3828e-01],\n",
       "          [-1.4941e-01, -1.7090e-01, -4.4141e-01,  ..., -5.5908e-02,\n",
       "           -4.6094e-01,  1.1953e+00],\n",
       "          [ 6.5625e-01, -9.7656e-04, -5.6641e-01,  ...,  1.5312e+00,\n",
       "            6.4453e-01,  1.0205e-01]],\n",
       "\n",
       "         [[ 4.1748e-02,  1.9531e-02, -3.2471e-02,  ..., -4.6484e-01,\n",
       "            7.1875e-01, -9.2163e-03],\n",
       "          [-3.1055e-01,  6.0547e-01,  3.1055e-01,  ..., -6.9531e-01,\n",
       "            1.7656e+00, -8.2031e-01],\n",
       "          [-1.3594e+00,  1.9043e-01,  1.5625e+00,  ...,  5.0391e-01,\n",
       "            2.3125e+00,  4.0234e-01],\n",
       "          ...,\n",
       "          [ 8.8281e-01, -6.3281e-01,  3.5547e-01,  ..., -9.6094e-01,\n",
       "           -1.0254e-01,  4.8828e-01],\n",
       "          [ 4.0625e-01,  1.1094e+00, -1.4941e-01,  ...,  4.9023e-01,\n",
       "           -2.0508e-01, -4.9023e-01],\n",
       "          [-4.5703e-01,  9.2578e-01, -3.4766e-01,  ...,  6.5918e-02,\n",
       "           -1.8359e-01,  6.3672e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-3.6377e-02, -1.4648e-02,  2.3804e-02,  ...,  5.6458e-03,\n",
       "            2.7954e-02, -3.8086e-02],\n",
       "          [ 8.2031e-02, -1.6022e-03, -2.1875e-01,  ...,  1.3477e-01,\n",
       "           -8.0859e-01, -8.2520e-02],\n",
       "          [-1.3965e-01,  5.3125e-01, -3.2422e-01,  ..., -3.7354e-02,\n",
       "           -4.3555e-01,  9.0332e-02],\n",
       "          ...,\n",
       "          [-3.7842e-02, -6.4062e-01,  9.5703e-01,  ...,  3.7109e-01,\n",
       "           -9.7656e-02,  7.3828e-01],\n",
       "          [ 4.2969e-01,  2.9688e-01,  5.6250e-01,  ...,  9.4922e-01,\n",
       "           -9.4922e-01,  1.5625e-01],\n",
       "          [-4.5508e-01, -2.8320e-01,  2.1729e-02,  ..., -4.2969e-01,\n",
       "           -6.9922e-01,  1.8066e-01]],\n",
       "\n",
       "         [[-3.4943e-03,  2.0264e-02,  3.8330e-02,  ...,  2.1240e-02,\n",
       "            7.5378e-03, -5.8594e-03],\n",
       "          [-8.8281e-01, -1.6309e-01,  8.0859e-01,  ..., -2.8320e-01,\n",
       "            2.2754e-01,  7.2656e-01],\n",
       "          [-1.0078e+00, -6.2109e-01,  1.0703e+00,  ..., -3.7305e-01,\n",
       "            7.3242e-02,  7.6953e-01],\n",
       "          ...,\n",
       "          [-3.3594e-01, -8.3984e-01,  9.2285e-02,  ..., -6.5625e-01,\n",
       "            2.2363e-01, -3.6523e-01],\n",
       "          [-1.5391e+00, -3.5938e-01, -5.9766e-01,  ..., -7.3438e-01,\n",
       "            6.6406e-01, -4.4727e-01],\n",
       "          [ 1.5723e-01, -5.7812e-01,  2.3730e-01,  ..., -3.9062e-01,\n",
       "            1.2598e-01,  8.2031e-01]],\n",
       "\n",
       "         [[-3.5156e-02,  4.4434e-02, -7.2021e-03,  ..., -5.2002e-02,\n",
       "           -7.7820e-04, -8.1787e-03],\n",
       "          [-3.0469e-01, -4.5703e-01, -3.0078e-01,  ..., -3.1445e-01,\n",
       "            1.2695e-01, -3.4375e-01],\n",
       "          [-8.2031e-02, -1.4062e-01, -3.9844e-01,  ..., -2.8711e-01,\n",
       "           -5.5420e-02, -5.2734e-01],\n",
       "          ...,\n",
       "          [ 5.6250e-01, -5.3711e-02, -1.5039e-01,  ..., -4.1260e-02,\n",
       "           -6.4941e-02, -1.4465e-02],\n",
       "          [-2.2949e-01, -6.7969e-01,  1.8164e-01,  ..., -7.8125e-02,\n",
       "           -5.1172e-01,  8.4961e-02],\n",
       "          [ 1.9629e-01, -4.3164e-01, -4.1211e-01,  ..., -1.5039e-01,\n",
       "           -2.8516e-01, -5.4443e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-5.4321e-03, -1.5747e-02,  2.5024e-03,  ..., -5.3406e-03,\n",
       "            2.4048e-02, -1.2207e-02],\n",
       "          [-1.3086e-01, -2.7734e-01, -4.1406e-01,  ...,  5.9326e-02,\n",
       "           -7.6562e-01, -8.9844e-01],\n",
       "          [-2.5000e-01,  7.5781e-01, -4.5471e-03,  ...,  3.4961e-01,\n",
       "           -7.5000e-01, -7.3438e-01],\n",
       "          ...,\n",
       "          [ 3.4961e-01,  1.7285e-01, -9.3359e-01,  ...,  4.4727e-01,\n",
       "            5.6250e-01,  8.7500e-01],\n",
       "          [-6.2988e-02,  4.7852e-01, -1.0234e+00,  ...,  8.3984e-01,\n",
       "            2.5586e-01,  1.1670e-01],\n",
       "          [-1.3086e-01, -2.7710e-02, -1.5234e+00,  ..., -2.2656e-01,\n",
       "            1.4844e-01, -4.1211e-01]],\n",
       "\n",
       "         [[ 6.8359e-03, -2.1851e-02, -1.6602e-02,  ..., -2.7954e-02,\n",
       "           -9.0332e-03, -2.6367e-02],\n",
       "          [-6.2891e-01, -1.6016e-01,  4.6387e-02,  ...,  1.3867e-01,\n",
       "           -6.8848e-02,  1.4551e-01],\n",
       "          [-3.3008e-01, -4.6484e-01, -6.9275e-03,  ..., -5.3906e-01,\n",
       "           -1.9727e-01, -2.9785e-02],\n",
       "          ...,\n",
       "          [ 1.9775e-02,  4.3555e-01,  5.0391e-01,  ..., -6.6016e-01,\n",
       "            6.4453e-02, -7.8516e-01],\n",
       "          [ 5.5176e-02, -4.2578e-01,  6.3672e-01,  ...,  7.7344e-01,\n",
       "            6.6406e-01, -6.8359e-01],\n",
       "          [ 1.6699e-01, -6.7578e-01,  2.4902e-01,  ...,  6.7871e-02,\n",
       "            2.0703e-01,  9.5703e-02]],\n",
       "\n",
       "         [[ 3.0365e-03, -1.9897e-02,  3.0396e-02,  ..., -5.7617e-02,\n",
       "            2.1240e-02, -5.5664e-02],\n",
       "          [ 1.1084e-01,  1.3672e-01,  1.4258e-01,  ...,  4.1797e-01,\n",
       "            8.9355e-02,  2.4316e-01],\n",
       "          [-6.6797e-01,  1.4453e-01,  4.4189e-02,  ...,  8.9453e-01,\n",
       "            2.6758e-01, -2.5000e-01],\n",
       "          ...,\n",
       "          [ 2.0752e-02,  5.2734e-01, -2.3315e-02,  ..., -2.6953e-01,\n",
       "            1.7188e-01,  2.2168e-01],\n",
       "          [ 5.1172e-01,  7.6562e-01, -1.4160e-01,  ..., -1.1377e-01,\n",
       "            3.8281e-01, -1.0449e-01],\n",
       "          [-1.7969e-01,  4.5703e-01, -2.1387e-01,  ..., -2.2852e-01,\n",
       "            8.5938e-01,  3.6523e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>))), hidden_states=None, attentions=None)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tokenize a fixed sample prompt and move the result to the CUDA device.\n",
    "# NOTE(review): despite its name, input_ids holds whatever tokenizer(...) returns\n",
    "# (presumably the full encoding, not only the id tensor) - confirm before renaming.\n",
    "# NOTE(review): the output saved with this cell is a CausalLMOutputWithPast, which a\n",
    "# bare assignment cannot display; the source was probably edited after the last run.\n",
    "# Re-run the cell (or clear its output) so the notebook matches Restart & Run All.\n",
    "input_ids = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\").to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CausalLMOutputWithPast(loss=None, logits=tensor([[[-12.5625,  -7.1250,  -0.6289,  ...,  -6.6562,  -7.9375,  -7.3125],\n",
       "         [-10.4375,  -6.7500,  -0.7852,  ...,  -7.0312,  -6.6562,  -8.2500],\n",
       "         [ -9.1250,  -6.1250,   2.4219,  ...,  -2.4844,  -2.9688,  -3.9375],\n",
       "         ...,\n",
       "         [-11.7500, -12.1250,   0.3535,  ...,  -7.4062,  -6.6875,  -7.3125],\n",
       "         [-10.1250, -10.2500,   2.0469,  ...,  -5.4062,  -5.2812,  -4.0312],\n",
       "         [ -3.1875,  -2.1875,  11.7500,  ...,   0.2109,  -0.8125,  -0.6953]]],\n",
       "       device='cuda:0', grad_fn=<ToCopyBackward0>), past_key_values=((tensor([[[[-4.5117e-01, -1.6968e-02,  5.1270e-02,  ...,  5.2246e-02,\n",
       "           -2.3804e-02, -1.2598e-01],\n",
       "          [-8.0469e-01,  4.6484e-01,  4.5898e-02,  ...,  4.0430e-01,\n",
       "           -1.2402e-01,  6.0547e-01],\n",
       "          [ 1.0859e+00, -3.9453e-01, -5.7031e-01,  ..., -3.2227e-01,\n",
       "            1.6113e-01, -3.6328e-01],\n",
       "          ...,\n",
       "          [-3.8867e-01,  3.6328e-01, -1.6504e-01,  ..., -2.1094e-01,\n",
       "            2.3145e-01, -1.6895e-01],\n",
       "          [-5.2490e-02,  7.4219e-02,  2.9297e-01,  ..., -9.6436e-03,\n",
       "            1.0059e-01,  1.1292e-02],\n",
       "          [-7.6172e-02, -3.1738e-02, -4.7607e-03,  ...,  1.3477e-01,\n",
       "            8.5449e-02,  1.7480e-01]],\n",
       "\n",
       "         [[ 1.3516e+00,  1.0781e+00, -4.2969e-01,  ...,  4.7656e-01,\n",
       "           -3.0469e-01,  4.7852e-01],\n",
       "          [ 7.4219e-01,  7.0312e-01, -3.3789e-01,  ..., -1.9629e-01,\n",
       "            3.4180e-01, -2.2656e-01],\n",
       "          [-1.5312e+00, -1.4062e+00,  3.3789e-01,  ..., -2.6562e-01,\n",
       "            2.9492e-01, -2.6953e-01],\n",
       "          ...,\n",
       "          [ 9.6094e-01,  1.9531e-01, -3.0469e-01,  ..., -4.6094e-01,\n",
       "            4.6631e-02, -3.4180e-01],\n",
       "          [ 7.4609e-01,  6.1719e-01, -3.7354e-02,  ..., -4.8633e-01,\n",
       "            1.4355e-01, -3.5742e-01],\n",
       "          [ 4.8828e-01,  4.6484e-01, -6.0156e-01,  ...,  7.3438e-01,\n",
       "           -1.6504e-01,  6.2109e-01]],\n",
       "\n",
       "         [[ 2.0142e-03, -2.5781e-01, -4.1602e-01,  ..., -1.2207e-01,\n",
       "            4.9805e-01,  7.6172e-01],\n",
       "          [-3.3789e-01, -1.0059e-01, -4.1016e-01,  ..., -4.7070e-01,\n",
       "           -5.1562e-01, -4.6680e-01],\n",
       "          [ 3.0273e-02, -2.4609e-01,  4.0527e-02,  ...,  1.4609e+00,\n",
       "            1.4922e+00,  1.3438e+00],\n",
       "          ...,\n",
       "          [-4.5117e-01, -2.8516e-01, -3.6133e-01,  ...,  1.7656e+00,\n",
       "            1.8047e+00,  1.6484e+00],\n",
       "          [-5.3906e-01, -2.6953e-01, -6.6895e-02,  ...,  1.6562e+00,\n",
       "            1.6719e+00,  1.5547e+00],\n",
       "          [-1.5625e-01, -3.3203e-01,  1.4258e-01,  ..., -9.8828e-01,\n",
       "           -1.1406e+00, -1.1016e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.4658e-02, -4.8523e-03,  3.9795e-02,  ...,  4.8633e-01,\n",
       "            9.6484e-01, -4.0234e-01],\n",
       "          [ 9.6680e-02,  7.3730e-02,  2.7954e-02,  ...,  1.8750e-01,\n",
       "            5.8984e-01, -1.1133e-01],\n",
       "          [ 1.1719e-02,  1.8047e+00,  4.1406e-01,  ..., -7.9297e-01,\n",
       "           -1.8203e+00, -1.5469e+00],\n",
       "          ...,\n",
       "          [ 1.5625e-02, -1.0703e+00,  7.4609e-01,  ..., -6.0938e-01,\n",
       "           -1.0391e+00, -1.3359e+00],\n",
       "          [-1.0156e+00, -7.8516e-01, -1.6309e-01,  ..., -6.3281e-01,\n",
       "           -8.9062e-01, -1.1172e+00],\n",
       "          [-6.1719e-01, -2.6562e-01, -3.2617e-01,  ..., -3.2422e-01,\n",
       "           -2.5781e-01, -6.9141e-01]],\n",
       "\n",
       "         [[ 1.7285e-01, -5.3906e-01, -3.1641e-01,  ..., -2.7344e-01,\n",
       "            1.3672e-01,  1.4062e-01],\n",
       "          [-1.0254e-01,  4.3555e-01, -8.5547e-01,  ...,  2.5391e-01,\n",
       "           -1.2256e-01, -1.1816e-01],\n",
       "          [ 2.7344e-01,  4.9805e-01, -1.0469e+00,  ..., -6.6797e-01,\n",
       "            3.0859e-01,  3.0859e-01],\n",
       "          ...,\n",
       "          [-5.2344e-01, -3.2812e-01,  6.8750e-01,  ..., -3.5547e-01,\n",
       "            9.2773e-02,  9.1797e-02],\n",
       "          [-4.7656e-01, -7.8125e-01,  9.1797e-01,  ..., -2.8906e-01,\n",
       "            8.1055e-02,  8.5449e-02],\n",
       "          [-4.3945e-02, -2.3730e-01, -7.0703e-01,  ...,  2.1680e-01,\n",
       "           -3.6865e-02, -3.6133e-02]],\n",
       "\n",
       "         [[-4.5117e-01,  1.2266e+00,  2.9144e-03,  ...,  7.1484e-01,\n",
       "            2.1582e-01,  2.7734e-01],\n",
       "          [-3.9648e-01, -9.0625e-01, -9.7266e-01,  ...,  1.2656e+00,\n",
       "           -6.4062e-01,  3.6328e-01],\n",
       "          [-1.8945e-01,  6.1523e-02,  5.5078e-01,  ..., -2.3750e+00,\n",
       "            1.1406e+00, -3.2031e-01],\n",
       "          ...,\n",
       "          [-2.6367e-02,  4.1748e-02, -1.4160e-02,  ..., -1.5703e+00,\n",
       "            9.8438e-01, -2.4902e-01],\n",
       "          [-1.6406e-01, -5.2246e-02, -3.3691e-02,  ..., -8.3594e-01,\n",
       "            4.7070e-01, -1.2207e-01],\n",
       "          [ 6.0547e-02, -4.1602e-01, -6.6406e-01,  ...,  8.2812e-01,\n",
       "           -4.0430e-01, -8.5449e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-6.0425e-03, -6.3782e-03,  5.5847e-03,  ...,  1.3809e-03,\n",
       "            2.0264e-02, -5.5237e-03],\n",
       "          [ 3.8719e-04, -8.7891e-03, -3.2806e-03,  ...,  5.3406e-03,\n",
       "            1.2131e-03, -6.1646e-03],\n",
       "          [-5.3406e-03,  3.0823e-03, -2.9297e-03,  ...,  5.1022e-05,\n",
       "           -2.4719e-03,  6.9275e-03],\n",
       "          ...,\n",
       "          [ 3.1586e-03,  1.7166e-03,  6.1035e-04,  ...,  6.3782e-03,\n",
       "            2.3193e-03, -9.2506e-05],\n",
       "          [ 4.0894e-03, -2.8534e-03,  3.5400e-03,  ..., -2.3746e-04,\n",
       "            3.7079e-03, -8.7280e-03],\n",
       "          [-2.2736e-03,  6.0425e-03,  7.2632e-03,  ...,  9.8419e-04,\n",
       "            8.9111e-03, -7.1716e-03]],\n",
       "\n",
       "         [[ 2.0599e-03,  4.0283e-03, -9.7046e-03,  ..., -4.7607e-03,\n",
       "            6.9275e-03, -1.9043e-02],\n",
       "          [ 8.1539e-05, -9.0790e-04,  2.3956e-03,  ...,  9.7046e-03,\n",
       "            2.1057e-03, -3.3417e-03],\n",
       "          [ 4.5204e-04,  2.8229e-04, -8.2397e-04,  ...,  2.4109e-03,\n",
       "            1.6327e-03,  9.3079e-04],\n",
       "          ...,\n",
       "          [-5.4321e-03,  1.8082e-03, -6.9275e-03,  ...,  1.4420e-03,\n",
       "           -4.1504e-03,  2.6131e-04],\n",
       "          [-4.4556e-03,  2.1973e-03,  2.9907e-03,  ..., -1.8768e-03,\n",
       "           -6.1035e-03, -1.8997e-03],\n",
       "          [ 2.1973e-03, -4.6387e-03,  5.6458e-04,  ...,  2.0905e-03,\n",
       "           -3.4637e-03,  8.3008e-03]],\n",
       "\n",
       "         [[-3.8605e-03, -6.4392e-03,  1.4221e-02,  ...,  5.0354e-03,\n",
       "           -1.4648e-02,  6.9885e-03],\n",
       "          [ 7.9956e-03, -6.7139e-03, -6.6757e-04,  ...,  5.0354e-03,\n",
       "            2.9449e-03, -3.7231e-03],\n",
       "          [-1.1253e-04,  2.4109e-03,  4.4632e-04,  ...,  6.0272e-04,\n",
       "            5.4321e-03, -4.6997e-03],\n",
       "          ...,\n",
       "          [-2.6131e-04,  1.1444e-03, -3.2663e-05,  ...,  1.2112e-04,\n",
       "            5.2490e-03, -1.1597e-03],\n",
       "          [ 1.3428e-03, -3.3112e-03, -8.0109e-04,  ..., -7.9727e-04,\n",
       "           -4.8447e-04,  6.1035e-03],\n",
       "          [ 1.9836e-03,  1.1749e-03, -2.5635e-03,  ..., -1.4114e-03,\n",
       "           -7.9346e-03,  3.8528e-04]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-5.9509e-04, -2.0752e-03, -1.5747e-02,  ...,  2.0752e-02,\n",
       "            9.4604e-03, -5.1880e-03],\n",
       "          [-1.0376e-02,  4.0283e-02, -6.9336e-02,  ..., -9.7656e-02,\n",
       "           -3.1494e-02, -2.0752e-02],\n",
       "          [-1.1536e-02, -1.7822e-02, -1.2024e-02,  ...,  6.0120e-03,\n",
       "           -9.0820e-02,  2.4414e-02],\n",
       "          ...,\n",
       "          [-8.8501e-03, -2.4414e-02, -3.0884e-02,  ...,  2.1606e-02,\n",
       "           -4.6387e-03, -2.1729e-02],\n",
       "          [ 2.2827e-02, -1.9287e-02,  3.4912e-02,  ...,  3.1738e-02,\n",
       "           -3.1738e-02,  3.9795e-02],\n",
       "          [ 1.0071e-02, -4.4922e-02, -7.8735e-03,  ..., -1.7456e-02,\n",
       "            5.6641e-02,  1.8433e-02]],\n",
       "\n",
       "         [[ 4.3106e-04,  1.5503e-02, -1.0254e-02,  ..., -8.6060e-03,\n",
       "           -4.5166e-02, -2.5024e-02],\n",
       "          [-6.5918e-03,  2.0386e-02, -6.9885e-03,  ..., -6.9580e-03,\n",
       "            7.9632e-05, -4.5776e-03],\n",
       "          [ 2.1820e-03,  2.2583e-03,  3.4027e-03,  ..., -2.4128e-04,\n",
       "           -5.9509e-03,  4.4250e-04],\n",
       "          ...,\n",
       "          [ 1.7853e-03,  3.9062e-03, -7.5150e-04,  ..., -2.8610e-04,\n",
       "            5.4626e-03,  1.2207e-03],\n",
       "          [ 3.6316e-03,  4.6997e-03,  1.2207e-03,  ...,  5.6458e-03,\n",
       "            4.0894e-03, -1.0376e-02],\n",
       "          [ 8.3008e-03, -1.8477e-06, -2.0294e-03,  ...,  3.8147e-03,\n",
       "           -5.3101e-03,  1.6479e-03]],\n",
       "\n",
       "         [[-3.9978e-03, -3.9368e-03,  4.6730e-04,  ...,  1.4572e-03,\n",
       "           -6.7139e-04, -1.4420e-03],\n",
       "          [ 2.9144e-03,  1.0498e-02, -9.3384e-03,  ...,  1.4191e-03,\n",
       "           -9.7656e-03,  4.1809e-03],\n",
       "          [-1.6327e-03,  1.3809e-03,  2.3499e-03,  ..., -6.0654e-04,\n",
       "           -1.9226e-03,  1.8234e-03],\n",
       "          ...,\n",
       "          [ 2.9449e-03, -1.0529e-03, -6.4697e-03,  ..., -3.0136e-04,\n",
       "            5.5847e-03,  3.4790e-03],\n",
       "          [ 3.3569e-03, -1.3794e-02,  8.6670e-03,  ..., -4.6730e-04,\n",
       "            1.4191e-03,  2.9449e-03],\n",
       "          [-7.2632e-03,  2.6093e-03, -1.9989e-03,  ...,  1.4709e-02,\n",
       "           -5.5237e-03, -1.4648e-03]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-0.3184,  0.7617,  0.5352,  ...,  0.1494,  0.1118, -0.2021],\n",
       "          [-1.7656,  0.1016, -0.6992,  ..., -0.8164, -0.5859,  0.2930],\n",
       "          [-0.7148,  0.0466, -1.2734,  ..., -0.5312, -0.3438, -0.0903],\n",
       "          ...,\n",
       "          [ 1.4375,  0.0913, -0.1094,  ..., -0.9375, -0.9219,  0.2139],\n",
       "          [-0.5703,  0.1055,  0.6055,  ..., -0.8320, -1.0938,  0.6172],\n",
       "          [-1.1953,  0.3926,  1.2266,  ..., -0.4551, -1.2344,  0.8438]],\n",
       "\n",
       "         [[-0.1016, -0.3770,  0.0405,  ...,  0.6875,  0.6758,  0.6680],\n",
       "          [ 0.1514, -0.1484, -1.0156,  ..., -1.1328, -0.4121, -0.3633],\n",
       "          [ 0.1270,  0.3047, -0.1865,  ..., -0.2217,  0.2178,  0.3594],\n",
       "          ...,\n",
       "          [-0.5312,  0.0752,  0.7539,  ..., -0.8516,  0.0728, -0.2168],\n",
       "          [ 0.1445, -0.7422,  1.0000,  ..., -0.8984, -0.3438, -0.5273],\n",
       "          [ 1.3438, -0.8086, -0.2617,  ..., -0.7734, -0.4062, -0.5664]],\n",
       "\n",
       "         [[-0.0544,  0.4590,  0.4805,  ...,  1.1875,  0.7227,  0.1001],\n",
       "          [-0.1074, -0.2793, -0.2402,  ...,  0.8477,  0.3984,  0.2715],\n",
       "          [-0.0762, -0.2637,  0.0835,  ...,  1.2344,  0.6836,  0.4141],\n",
       "          ...,\n",
       "          [ 0.3574,  0.3633,  0.2168,  ...,  1.1016,  0.4883,  0.5664],\n",
       "          [ 0.5195,  0.1377,  0.0713,  ...,  1.0781,  0.5352,  0.4629],\n",
       "          [ 0.3340, -0.0474, -0.0713,  ...,  1.0000,  0.3906,  0.5117]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-0.7070,  0.0168,  0.0400,  ...,  0.9219, -0.3340, -0.0630],\n",
       "          [-0.1504, -0.1182, -0.1670,  ..., -0.0684, -0.1631,  0.0532],\n",
       "          [ 0.6133, -0.1455,  0.1533,  ...,  0.5820, -0.2773,  0.2314],\n",
       "          ...,\n",
       "          [-0.6016,  0.0640, -0.1777,  ..., -0.2539, -0.0356,  0.1396],\n",
       "          [-0.0059, -0.0669, -0.4492,  ..., -0.0022, -0.1309,  0.1650],\n",
       "          [ 0.3105, -0.1611, -0.3008,  ..., -0.0208, -0.0027,  0.0928]],\n",
       "\n",
       "         [[ 0.2217,  0.2490, -0.1631,  ..., -0.4531,  0.7031, -0.0097],\n",
       "          [-0.0352, -0.4199,  0.0034,  ..., -0.1973, -0.0830, -0.3555],\n",
       "          [-0.2402, -0.1484,  0.1348,  ..., -0.6094,  0.6016,  0.0085],\n",
       "          ...,\n",
       "          [ 0.4043, -0.0337, -0.0308,  ...,  0.1001, -0.2441, -0.0776],\n",
       "          [ 0.5742, -0.1943,  0.4062,  ..., -0.1050, -0.3008,  0.3672],\n",
       "          [-0.1914, -0.4844,  0.7383,  ...,  0.1279, -0.5117,  0.1001]],\n",
       "\n",
       "         [[-0.0182,  0.0591,  0.0898,  ...,  0.2012, -0.2031,  0.1094],\n",
       "          [-0.4707,  0.4551,  0.4414,  ...,  1.0781, -1.1250,  0.7344],\n",
       "          [ 0.1367,  0.0889,  0.2598,  ...,  0.7383, -0.7891,  0.4434],\n",
       "          ...,\n",
       "          [-0.3945, -0.3984, -0.2168,  ...,  1.2188, -1.2422,  0.8555],\n",
       "          [-0.8945, -0.2656, -0.3184,  ...,  1.3125, -1.3516,  0.9492],\n",
       "          [-0.7891,  0.0869, -0.0615,  ...,  1.4062, -1.4141,  1.0078]]]],\n",
       "       device='cuda:0', dtype=torch.bfloat16, grad_fn=<AddBackward0>), tensor([[[[ 8.9355e-02, -4.1199e-03, -3.0518e-02,  ...,  1.6846e-02,\n",
       "           -1.8311e-02,  6.5430e-02],\n",
       "          [-1.0498e-02,  1.1084e-01,  4.8584e-02,  ...,  3.7598e-02,\n",
       "           -1.4465e-02, -6.3965e-02],\n",
       "          [ 8.1543e-02,  6.9580e-03,  1.1658e-02,  ...,  1.1902e-02,\n",
       "            3.9482e-04, -1.5869e-02],\n",
       "          ...,\n",
       "          [-1.1426e-01,  1.3574e-01, -6.8970e-03,  ...,  2.9755e-03,\n",
       "           -4.7852e-02,  3.4180e-02],\n",
       "          [-1.3672e-01,  6.0547e-02,  2.9541e-02,  ..., -1.3916e-02,\n",
       "            1.0376e-02, -2.7710e-02],\n",
       "          [-4.1992e-02, -7.1106e-03, -2.5024e-02,  ..., -3.2715e-02,\n",
       "            3.6133e-02,  7.5684e-03]],\n",
       "\n",
       "         [[-9.8419e-04,  7.4005e-04, -9.0942e-03,  ...,  2.6550e-03,\n",
       "            8.7280e-03,  7.8735e-03],\n",
       "          [ 1.0254e-02, -1.5991e-02, -8.3618e-03,  ...,  1.5015e-02,\n",
       "           -2.7466e-03, -1.0742e-02],\n",
       "          [ 9.1553e-03,  4.3335e-03,  3.2501e-03,  ...,  7.1106e-03,\n",
       "            5.0659e-03, -7.2327e-03],\n",
       "          ...,\n",
       "          [ 5.6458e-03, -9.4604e-03, -7.2479e-04,  ...,  5.6458e-04,\n",
       "            1.0193e-02,  6.6833e-03],\n",
       "          [ 4.7607e-03, -1.1292e-02,  2.0409e-04,  ..., -6.2561e-03,\n",
       "            1.0376e-02,  6.6833e-03],\n",
       "          [-2.7466e-02, -1.4526e-02,  8.7280e-03,  ...,  1.8188e-02,\n",
       "            1.1536e-02, -2.1820e-03]],\n",
       "\n",
       "         [[-3.2471e-02, -1.3428e-02, -2.3804e-02,  ..., -8.9355e-02,\n",
       "           -9.5215e-02,  6.8359e-02],\n",
       "          [ 4.4189e-02,  1.1658e-02,  8.3984e-02,  ...,  1.2598e-01,\n",
       "            7.2754e-02, -1.1328e-01],\n",
       "          [-3.0396e-02,  1.4648e-02, -3.5889e-02,  ..., -5.0049e-02,\n",
       "           -4.3457e-02,  6.6406e-02],\n",
       "          ...,\n",
       "          [ 9.3750e-02,  5.4016e-03,  2.5146e-02,  ..., -7.9102e-02,\n",
       "            1.9531e-01, -5.1270e-02],\n",
       "          [ 3.9307e-02,  3.1128e-02,  1.5747e-02,  ...,  9.4727e-02,\n",
       "           -3.6377e-02, -2.9102e-01],\n",
       "          [-1.1597e-02,  1.5747e-02, -4.2725e-03,  ..., -1.5039e-01,\n",
       "           -7.5684e-02,  2.6367e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.9053e-02,  1.9287e-02, -1.3542e-04,  ..., -3.5400e-02,\n",
       "           -1.2695e-02, -2.0142e-02],\n",
       "          [-1.7212e-02, -4.9133e-03,  1.9897e-02,  ...,  9.0942e-03,\n",
       "           -1.7700e-02,  1.4099e-02],\n",
       "          [-2.2278e-03, -1.1444e-03,  4.5776e-03,  ...,  2.5749e-04,\n",
       "           -1.6479e-03, -1.2390e-02],\n",
       "          ...,\n",
       "          [ 2.2888e-03,  1.4191e-03,  2.4658e-02,  ..., -1.6602e-02,\n",
       "           -1.5869e-02,  4.7913e-03],\n",
       "          [-1.0315e-02, -1.0864e-02,  3.2471e-02,  ..., -2.0142e-03,\n",
       "           -1.3062e-02,  1.6113e-02],\n",
       "          [ 1.3916e-02, -8.2397e-03,  1.9836e-03,  ..., -1.0315e-02,\n",
       "            5.6641e-02, -5.0659e-03]],\n",
       "\n",
       "         [[-1.5234e-01, -2.0142e-02, -1.3379e-01,  ..., -2.8076e-03,\n",
       "            5.5542e-03, -8.9111e-03],\n",
       "          [-8.3008e-02, -6.7139e-03,  2.5195e-01,  ..., -1.4404e-02,\n",
       "            1.3123e-02,  2.9297e-03],\n",
       "          [-1.5137e-02,  2.3926e-02, -2.0020e-02,  ...,  2.2316e-04,\n",
       "           -5.8289e-03, -2.2339e-02],\n",
       "          ...,\n",
       "          [-3.3203e-02, -1.0645e-01, -4.0283e-02,  ...,  6.1646e-03,\n",
       "            4.0588e-03, -6.8970e-03],\n",
       "          [-2.5757e-02, -1.8463e-03,  1.8262e-01,  ..., -1.8311e-02,\n",
       "            3.7354e-02,  1.1719e-02],\n",
       "          [-2.1191e-01,  2.3315e-02,  6.9336e-02,  ..., -8.7280e-03,\n",
       "           -4.2114e-03,  2.2583e-03]],\n",
       "\n",
       "         [[-3.9368e-03,  1.2024e-02,  1.1658e-02,  ..., -2.2430e-03,\n",
       "            8.4229e-03,  4.3640e-03],\n",
       "          [ 2.9907e-02,  1.1230e-02, -6.7749e-03,  ...,  2.4719e-03,\n",
       "           -1.0742e-02,  2.3682e-02],\n",
       "          [-5.7678e-03, -4.9744e-03,  6.0120e-03,  ...,  1.8501e-04,\n",
       "           -3.3112e-03, -4.1199e-03],\n",
       "          ...,\n",
       "          [-2.0294e-03,  1.4709e-02,  1.1292e-02,  ..., -7.2327e-03,\n",
       "            5.0659e-03,  8.3618e-03],\n",
       "          [-7.4158e-03, -6.3171e-03,  1.2817e-02,  ...,  1.0010e-02,\n",
       "            4.2419e-03,  3.9062e-03],\n",
       "          [ 3.0670e-03,  2.4414e-02, -7.3242e-03,  ..., -7.7209e-03,\n",
       "           -2.4414e-02,  3.0518e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 7.9346e-03,  2.7710e-02,  1.5564e-02,  ..., -1.8750e-01,\n",
       "           -1.1768e-01,  1.4746e-01],\n",
       "          [-1.1250e+00, -1.2988e-01,  4.8828e-01,  ...,  6.9922e-01,\n",
       "            3.5742e-01, -3.3594e-01],\n",
       "          [-1.2266e+00,  0.0000e+00,  3.5352e-01,  ...,  2.4414e-01,\n",
       "           -7.9102e-02, -6.6016e-01],\n",
       "          ...,\n",
       "          [ 1.6562e+00,  3.4766e-01,  1.2402e-01,  ..., -5.7031e-01,\n",
       "            2.8125e-01, -2.1484e-01],\n",
       "          [ 1.1875e+00, -1.0010e-01, -1.5234e-01,  ..., -8.5938e-01,\n",
       "            1.2031e+00, -8.3203e-01],\n",
       "          [ 0.0000e+00, -4.4531e-01, -5.9375e-01,  ...,  6.6797e-01,\n",
       "            8.3594e-01, -1.0625e+00]],\n",
       "\n",
       "         [[-1.9653e-02,  2.4048e-02,  5.3223e-02,  ...,  2.4902e-02,\n",
       "            3.3398e-01, -3.4570e-01],\n",
       "          [ 3.2031e-01,  0.0000e+00,  9.2773e-02,  ..., -3.1250e-01,\n",
       "           -1.6016e+00,  1.6641e+00],\n",
       "          [-1.8164e-01, -1.5039e-01, -2.8125e-01,  ...,  3.2422e-01,\n",
       "           -2.1562e+00,  2.3438e+00],\n",
       "          ...,\n",
       "          [ 2.1875e-01,  6.1279e-02,  3.4375e-01,  ..., -8.3594e-01,\n",
       "           -2.1406e+00,  2.3906e+00],\n",
       "          [-9.2285e-02, -6.8359e-02, -1.4844e-01,  ..., -3.0078e-01,\n",
       "           -1.9062e+00,  1.8672e+00],\n",
       "          [ 4.4922e-01,  7.9297e-01, -5.2344e-01,  ..., -3.8281e-01,\n",
       "           -1.6953e+00,  1.9453e+00]],\n",
       "\n",
       "         [[-3.1128e-02,  7.4707e-02,  2.5024e-02,  ..., -8.6914e-02,\n",
       "            3.4180e-01, -3.5156e-01],\n",
       "          [ 3.0469e+00, -2.1719e+00,  4.0234e-01,  ...,  1.6953e+00,\n",
       "           -1.7734e+00,  1.7656e+00],\n",
       "          [ 1.1094e+00, -1.2969e+00,  1.3477e-01,  ...,  1.2734e+00,\n",
       "           -1.9297e+00,  2.6875e+00],\n",
       "          ...,\n",
       "          [-1.9219e+00,  1.0625e+00, -8.9844e-01,  ...,  3.2031e+00,\n",
       "           -2.6562e+00,  2.4844e+00],\n",
       "          [ 1.5625e-01,  1.2734e+00,  9.7656e-04,  ...,  1.2969e+00,\n",
       "           -9.6875e-01,  2.0312e+00],\n",
       "          [ 1.9062e+00,  6.5430e-02, -7.9590e-02,  ..., -1.2207e-01,\n",
       "           -3.0469e+00,  2.0469e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 3.3936e-02,  4.3640e-03,  2.0020e-02,  ...,  3.5889e-02,\n",
       "            3.7598e-02,  9.5215e-02],\n",
       "          [ 3.7500e-01,  1.2500e-01,  8.9355e-02,  ...,  2.4219e-01,\n",
       "           -5.7422e-01, -2.6758e-01],\n",
       "          [ 1.5430e-01, -4.6289e-01,  2.6758e-01,  ...,  3.8477e-01,\n",
       "           -4.7266e-01,  8.6914e-02],\n",
       "          ...,\n",
       "          [-4.3359e-01, -1.2207e-02,  1.2256e-01,  ..., -2.3145e-01,\n",
       "           -1.7773e-01, -1.6113e-02],\n",
       "          [-5.8203e-01,  1.0742e-01, -9.1309e-02,  ..., -1.1377e-01,\n",
       "            1.5625e-01, -8.8281e-01],\n",
       "          [-1.1328e-01, -1.8555e-02,  8.1787e-03,  ...,  3.6523e-01,\n",
       "           -1.5918e-01, -7.2266e-01]],\n",
       "\n",
       "         [[-7.9346e-03,  3.4424e-02,  9.8267e-03,  ..., -7.2754e-02,\n",
       "           -6.4941e-02,  3.4180e-03],\n",
       "          [-1.2969e+00,  6.2109e-01, -1.6797e-01,  ...,  3.3203e-01,\n",
       "           -4.6680e-01, -5.9375e-01],\n",
       "          [-1.4062e+00,  1.4453e+00,  7.8125e-02,  ...,  2.4414e-01,\n",
       "            7.4219e-01, -1.0781e+00],\n",
       "          ...,\n",
       "          [ 2.0156e+00, -6.2500e-01,  3.2422e-01,  ..., -2.8125e-01,\n",
       "           -3.1250e-01, -4.5703e-01],\n",
       "          [ 1.3438e+00, -2.2344e+00, -1.0938e+00,  ...,  2.5625e+00,\n",
       "            1.7031e+00,  1.9141e+00],\n",
       "          [-5.6641e-01, -1.8984e+00, -7.1484e-01,  ...,  7.3047e-01,\n",
       "            1.6328e+00, -9.7656e-02]],\n",
       "\n",
       "         [[-4.9316e-02, -2.2339e-02, -1.8311e-02,  ...,  1.0791e-01,\n",
       "            4.7266e-01,  8.6426e-02],\n",
       "          [ 1.0781e+00, -4.8438e-01, -1.0859e+00,  ..., -4.1016e-01,\n",
       "           -2.1094e+00, -2.5586e-01],\n",
       "          [-5.3906e-01, -1.1797e+00, -5.3125e-01,  ...,  3.4375e-01,\n",
       "           -2.5469e+00,  8.0566e-02],\n",
       "          ...,\n",
       "          [ 5.1562e-01,  3.9648e-01,  1.0391e+00,  ..., -1.3984e+00,\n",
       "           -2.7344e+00,  1.0156e+00],\n",
       "          [ 1.6250e+00,  6.2891e-01,  8.7500e-01,  ..., -1.6016e+00,\n",
       "           -2.5625e+00, -6.5234e-01],\n",
       "          [ 1.4531e+00,  4.7656e-01,  3.9062e-01,  ..., -7.6172e-01,\n",
       "           -2.2500e+00, -1.0859e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 3.8300e-03, -9.2316e-04,  4.5776e-03,  ...,  1.5163e-04,\n",
       "           -9.9182e-05, -1.3504e-03],\n",
       "          [-2.2949e-02,  1.5527e-01,  2.0410e-01,  ...,  7.1289e-02,\n",
       "            3.8330e-02, -5.1514e-02],\n",
       "          [-9.6680e-02, -1.1169e-02, -1.5527e-01,  ..., -5.1025e-02,\n",
       "            1.8066e-02, -2.3438e-02],\n",
       "          ...,\n",
       "          [ 4.0039e-02,  8.0566e-02,  1.3855e-02,  ...,  1.1621e-01,\n",
       "           -7.8125e-02, -2.1680e-01],\n",
       "          [-9.6191e-02, -6.1719e-01, -1.2207e-01,  ...,  3.2715e-02,\n",
       "           -1.1670e-01,  9.7168e-02],\n",
       "          [-1.6602e-02, -1.4062e-01,  2.3828e-01,  ...,  1.9043e-01,\n",
       "            3.2812e-01, -3.8672e-01]],\n",
       "\n",
       "         [[ 2.7832e-02,  7.0190e-04,  3.7384e-03,  ..., -9.2316e-04,\n",
       "           -3.1433e-03,  1.6708e-03],\n",
       "          [-9.8047e-01, -1.5137e-01, -1.6602e-01,  ..., -1.9409e-02,\n",
       "           -1.6699e-01,  1.0059e-01],\n",
       "          [-3.5742e-01, -9.7168e-02, -2.5391e-01,  ..., -2.4414e-01,\n",
       "           -6.5430e-02, -4.1748e-02],\n",
       "          ...,\n",
       "          [-2.5977e-01,  6.2012e-02, -6.0059e-02,  ..., -5.5664e-02,\n",
       "            1.0400e-01, -3.5400e-02],\n",
       "          [-7.6562e-01, -2.6953e-01, -2.6172e-01,  ...,  9.9609e-02,\n",
       "           -4.9316e-02, -1.4062e-01],\n",
       "          [-6.8359e-01,  4.6387e-02, -1.6406e-01,  ..., -2.8906e-01,\n",
       "           -6.5918e-02,  9.2285e-02]],\n",
       "\n",
       "         [[-6.3419e-05, -3.0060e-03,  5.8899e-03,  ...,  1.7776e-03,\n",
       "           -1.1597e-03, -1.2283e-03],\n",
       "          [ 1.5723e-01,  2.0703e-01, -9.7656e-02,  ..., -3.0273e-02,\n",
       "            3.2227e-02, -7.8613e-02],\n",
       "          [-7.3730e-02,  1.2109e-01,  1.3672e-01,  ...,  4.1199e-03,\n",
       "            1.1523e-01,  1.0596e-01],\n",
       "          ...,\n",
       "          [-2.1289e-01,  1.4343e-02,  9.1797e-02,  ..., -8.2031e-02,\n",
       "            2.4658e-02,  3.9551e-02],\n",
       "          [ 1.2061e-01,  4.0283e-02,  9.6680e-02,  ...,  6.0791e-02,\n",
       "            1.1963e-01,  2.4658e-02],\n",
       "          [ 2.4707e-01, -4.9316e-02, -2.7930e-01,  ..., -6.9336e-02,\n",
       "            9.5703e-02, -1.7090e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 4.4556e-03,  2.0752e-02, -4.8218e-03,  ..., -5.1880e-03,\n",
       "            1.8597e-04,  4.6692e-03],\n",
       "          [ 1.0742e-01, -2.3145e-01,  2.3633e-01,  ..., -1.0864e-02,\n",
       "           -1.5137e-01,  5.5908e-02],\n",
       "          [-6.5430e-02, -5.1270e-03,  1.6504e-01,  ..., -4.9561e-02,\n",
       "           -3.1836e-01,  6.2012e-02],\n",
       "          ...,\n",
       "          [-2.4902e-02,  1.8848e-01,  3.6865e-02,  ..., -1.9043e-01,\n",
       "           -2.1973e-02, -4.0283e-02],\n",
       "          [-3.8086e-02,  1.5430e-01,  1.3477e-01,  ..., -4.5898e-01,\n",
       "           -4.9219e-01, -6.8359e-02],\n",
       "          [ 1.6504e-01, -6.9922e-01, -1.4746e-01,  ...,  7.2266e-01,\n",
       "           -1.2354e-01,  1.0059e-01]],\n",
       "\n",
       "         [[ 4.0894e-03, -9.6436e-03,  2.0599e-03,  ...,  4.8828e-04,\n",
       "           -1.5182e-03, -2.3193e-03],\n",
       "          [-2.9297e-02,  2.0264e-02, -9.0942e-03,  ...,  1.3281e-01,\n",
       "           -1.9409e-02,  2.0605e-01],\n",
       "          [ 3.0151e-02, -3.8281e-01,  1.2283e-03,  ..., -2.0312e-01,\n",
       "           -2.5781e-01, -1.8311e-02],\n",
       "          ...,\n",
       "          [-5.3223e-02, -4.3213e-02,  8.1787e-03,  ..., -1.0300e-03,\n",
       "           -8.6426e-02, -1.0132e-02],\n",
       "          [ 1.4941e-01, -4.3359e-01,  2.4109e-03,  ..., -3.3203e-01,\n",
       "           -1.6895e-01, -4.2480e-02],\n",
       "          [ 1.6406e-01, -3.0664e-01,  1.5918e-01,  ...,  4.4189e-02,\n",
       "           -2.7148e-01, -1.9629e-01]],\n",
       "\n",
       "         [[ 2.5177e-03, -3.4637e-03, -8.0585e-05,  ...,  1.8311e-02,\n",
       "           -6.5613e-04, -2.1820e-03],\n",
       "          [ 3.7500e-01,  1.3086e-01, -7.0801e-02,  ..., -3.8477e-01,\n",
       "           -2.9297e-01,  2.8711e-01],\n",
       "          [-4.5898e-02,  9.1309e-02, -2.6172e-01,  ..., -5.4932e-02,\n",
       "           -1.3379e-01,  4.5654e-02],\n",
       "          ...,\n",
       "          [ 1.1230e-01,  1.2891e-01,  3.3984e-01,  ..., -2.7539e-01,\n",
       "           -6.4941e-02,  2.1680e-01],\n",
       "          [-2.5879e-02,  1.3086e-01, -2.9297e-01,  ..., -9.0332e-02,\n",
       "           -5.2734e-02, -9.9121e-02],\n",
       "          [-1.3770e-01,  9.5215e-03,  3.3936e-02,  ..., -9.0820e-02,\n",
       "            1.3770e-01, -1.1035e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 1.0681e-02, -1.8845e-03,  1.2451e-02,  ...,  5.5859e-01,\n",
       "            5.5176e-02, -5.5469e-01],\n",
       "          [-8.5156e-01,  4.8438e-01,  1.5820e-01,  ..., -2.2344e+00,\n",
       "           -1.1094e+00,  2.4688e+00],\n",
       "          [ 2.4316e-01, -4.5508e-01,  1.8164e-01,  ..., -2.3594e+00,\n",
       "           -5.9375e-01,  2.4688e+00],\n",
       "          ...,\n",
       "          [-4.1797e-01,  7.1875e-01,  4.3164e-01,  ..., -2.7812e+00,\n",
       "           -9.3750e-01,  2.8906e+00],\n",
       "          [-4.7070e-01,  1.4844e-01,  2.9297e-01,  ..., -3.2969e+00,\n",
       "            3.1406e+00,  3.4531e+00],\n",
       "          [-1.6016e+00,  5.7422e-01,  4.4141e-01,  ..., -2.9062e+00,\n",
       "           -2.2188e+00,  3.0156e+00]],\n",
       "\n",
       "         [[ 2.2736e-03,  3.4180e-02, -1.7822e-02,  ..., -4.2383e-01,\n",
       "           -4.5703e-01, -3.6523e-01],\n",
       "          [ 2.8125e-01, -1.2988e-01, -9.4727e-02,  ...,  8.9062e-01,\n",
       "            1.2578e+00,  1.2031e+00],\n",
       "          [-5.4199e-02, -7.5195e-02,  3.8574e-02,  ...,  6.4844e-01,\n",
       "            1.5938e+00,  4.8828e-01],\n",
       "          ...,\n",
       "          [ 2.7539e-01,  1.8555e-01, -3.2227e-02,  ...,  9.3750e-01,\n",
       "            1.5078e+00,  7.8125e-01],\n",
       "          [ 2.0703e-01,  5.6641e-02, -7.8125e-02,  ...,  8.0859e-01,\n",
       "            7.8516e-01,  7.9688e-01],\n",
       "          [ 1.0645e-01,  1.0742e-01, -6.3477e-02,  ...,  1.7031e+00,\n",
       "            1.4766e+00,  8.1250e-01]],\n",
       "\n",
       "         [[ 1.8799e-02, -1.5564e-02,  3.7193e-05,  ...,  1.1641e+00,\n",
       "           -1.2061e-01,  1.9824e-01],\n",
       "          [-7.1094e-01,  7.8125e-01,  6.3281e-01,  ..., -4.7812e+00,\n",
       "            8.3203e-01, -1.3516e+00],\n",
       "          [ 3.7109e-01,  3.4180e-01, -4.7461e-01,  ..., -5.4688e+00,\n",
       "            9.1406e-01, -6.9141e-01],\n",
       "          ...,\n",
       "          [ 1.6602e-02, -2.5586e-01,  4.4727e-01,  ..., -5.5625e+00,\n",
       "           -1.4844e+00, -1.4355e-01],\n",
       "          [-8.2812e-01,  4.1992e-01,  1.0156e-01,  ..., -6.1250e+00,\n",
       "           -4.8438e-01,  1.7480e-01],\n",
       "          [-1.0156e+00,  7.9688e-01,  2.0898e-01,  ..., -5.0938e+00,\n",
       "           -2.0312e-01, -1.6172e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 6.2012e-02, -1.1475e-02,  7.3853e-03,  ...,  8.2031e-01,\n",
       "           -8.0078e-02,  7.0312e-02],\n",
       "          [ 1.1328e+00,  5.4688e-01,  4.7461e-01,  ..., -2.3125e+00,\n",
       "           -1.6797e-01,  1.4062e+00],\n",
       "          [ 4.1797e-01,  7.6562e-01,  9.9121e-02,  ..., -3.2969e+00,\n",
       "           -2.5586e-01, -2.6367e-01],\n",
       "          ...,\n",
       "          [-4.2578e-01, -7.5000e-01, -2.7466e-02,  ..., -3.2031e+00,\n",
       "           -4.5703e-01,  7.8906e-01],\n",
       "          [ 2.4023e-01,  8.7891e-02, -2.0605e-01,  ..., -3.7969e+00,\n",
       "           -3.7891e-01, -7.3828e-01],\n",
       "          [ 5.6250e-01, -2.9492e-01, -2.8906e-01,  ..., -2.9531e+00,\n",
       "           -3.0859e-01, -4.9316e-02]],\n",
       "\n",
       "         [[ 2.5757e-02,  3.3691e-02, -3.5400e-02,  ...,  8.9844e-02,\n",
       "           -9.0820e-02, -1.3977e-02],\n",
       "          [-5.9375e-01, -4.4922e-01,  9.0234e-01,  ..., -1.2188e+00,\n",
       "           -7.6562e-01, -6.5234e-01],\n",
       "          [-1.8945e-01, -7.4219e-02,  9.6191e-02,  ..., -1.2500e+00,\n",
       "            1.6016e+00, -6.0547e-01],\n",
       "          ...,\n",
       "          [ 4.5898e-01,  5.5469e-01, -1.7700e-02,  ..., -1.7031e+00,\n",
       "            2.8320e-01,  2.0410e-01],\n",
       "          [-4.3750e-01,  1.4453e-01, -9.5703e-02,  ..., -1.4062e-01,\n",
       "            2.0469e+00,  6.9922e-01],\n",
       "          [-4.4141e-01, -8.0469e-01, -1.9141e-01,  ..., -1.2500e+00,\n",
       "            1.1641e+00,  1.7480e-01]],\n",
       "\n",
       "         [[ 2.6245e-02,  1.7456e-02, -1.0071e-03,  ...,  7.0312e-01,\n",
       "           -6.9922e-01, -5.9766e-01],\n",
       "          [-2.2852e-01,  2.2070e-01,  4.3945e-01,  ..., -4.7188e+00,\n",
       "            1.8516e+00,  9.8438e-01],\n",
       "          [-4.1016e-02,  4.3555e-01,  1.6113e-01,  ..., -3.7031e+00,\n",
       "            1.5469e+00,  1.9766e+00],\n",
       "          ...,\n",
       "          [-1.1719e-01, -7.0312e-02, -1.0742e-01,  ..., -5.0312e+00,\n",
       "            7.1562e+00,  4.4062e+00],\n",
       "          [-4.2578e-01, -2.5391e-01, -1.3672e-01,  ..., -4.6875e+00,\n",
       "            5.8125e+00,  2.2969e+00],\n",
       "          [-6.3672e-01, -3.6914e-01,  2.0020e-01,  ..., -5.8594e-01,\n",
       "            1.3359e+00,  5.5312e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 2.5940e-03, -5.1270e-03,  8.6975e-04,  ..., -4.1199e-03,\n",
       "           -1.6861e-03,  4.6387e-03],\n",
       "          [ 3.4180e-01,  6.8848e-02,  5.4688e-01,  ...,  4.3164e-01,\n",
       "           -1.9922e-01, -1.0986e-02],\n",
       "          [ 2.5000e-01,  1.2988e-01, -1.8457e-01,  ..., -2.7734e-01,\n",
       "           -5.5469e-01,  4.1016e-02],\n",
       "          ...,\n",
       "          [ 1.2109e-01, -7.1106e-03, -9.8633e-02,  ..., -3.0151e-02,\n",
       "            4.3555e-01,  2.0215e-01],\n",
       "          [ 1.0205e-01,  1.7480e-01, -1.6895e-01,  ..., -2.0605e-01,\n",
       "            2.5195e-01, -3.0396e-02],\n",
       "          [-5.7068e-03,  1.6016e-01,  5.4688e-01,  ...,  1.7188e-01,\n",
       "           -1.9922e-01,  2.0996e-01]],\n",
       "\n",
       "         [[ 7.9346e-04, -5.9128e-04, -3.6430e-04,  ..., -6.9275e-03,\n",
       "            2.8076e-03, -2.1973e-03],\n",
       "          [ 2.2656e-01,  2.5977e-01,  2.4219e-01,  ...,  1.9238e-01,\n",
       "            3.3398e-01, -4.6631e-02],\n",
       "          [-4.7607e-03,  2.4414e-01, -4.1260e-02,  ...,  2.2461e-01,\n",
       "            2.6562e-01, -2.4902e-01],\n",
       "          ...,\n",
       "          [ 6.1719e-01,  2.6953e-01,  2.8931e-02,  ...,  2.7734e-01,\n",
       "            7.8125e-02,  1.4551e-01],\n",
       "          [ 1.6602e-01,  1.7188e-01, -1.7285e-01,  ..., -4.9744e-03,\n",
       "            4.6143e-02,  6.6895e-02],\n",
       "          [-2.7344e-01, -1.4941e-01, -1.6113e-01,  ..., -9.6191e-02,\n",
       "           -4.3457e-02, -2.9492e-01]],\n",
       "\n",
       "         [[-1.0452e-03,  5.2795e-03, -1.0132e-02,  ...,  2.7466e-03,\n",
       "           -5.4626e-03,  1.9684e-03],\n",
       "          [ 1.8066e-01,  4.6289e-01,  3.0859e-01,  ..., -2.2461e-01,\n",
       "           -2.9297e-01, -2.0508e-01],\n",
       "          [ 8.8379e-02, -1.5820e-01, -4.6680e-01,  ..., -3.5156e-02,\n",
       "           -7.2327e-03,  3.9258e-01],\n",
       "          ...,\n",
       "          [ 2.5586e-01, -1.4404e-02, -4.2969e-01,  ..., -1.4160e-01,\n",
       "            3.2617e-01, -6.8359e-02],\n",
       "          [-3.0469e-01,  2.1582e-01, -7.4707e-02,  ...,  2.3926e-01,\n",
       "            1.2500e-01, -3.7109e-01],\n",
       "          [ 8.7891e-02,  1.7383e-01,  2.4414e-01,  ...,  9.5703e-02,\n",
       "            3.6719e-01,  1.3086e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 6.2561e-03,  4.2725e-03,  5.7983e-03,  ..., -2.5330e-03,\n",
       "           -4.6082e-03, -1.4019e-04],\n",
       "          [ 1.2354e-01,  1.5234e-01, -3.4766e-01,  ...,  3.4668e-02,\n",
       "           -1.0803e-02, -5.8594e-01],\n",
       "          [-1.0449e-01, -2.2168e-01,  1.9141e-01,  ..., -2.7148e-01,\n",
       "            1.2988e-01,  2.2266e-01],\n",
       "          ...,\n",
       "          [-7.3242e-02,  6.9922e-01, -3.5352e-01,  ...,  9.0942e-03,\n",
       "            4.7852e-01,  3.8086e-01],\n",
       "          [-2.5977e-01,  1.9922e-01,  5.6250e-01,  ...,  1.2402e-01,\n",
       "            7.0312e-02, -1.1816e-01],\n",
       "          [-2.8906e-01, -1.2598e-01,  3.5742e-01,  ...,  9.2773e-02,\n",
       "           -1.8750e-01, -5.1172e-01]],\n",
       "\n",
       "         [[-4.5776e-03, -3.7842e-03,  6.9885e-03,  ...,  1.1902e-03,\n",
       "           -5.0049e-03,  9.6130e-04],\n",
       "          [-5.9570e-02,  6.3672e-01, -5.0781e-01,  ...,  1.8066e-01,\n",
       "           -1.8359e-01, -1.5625e-02],\n",
       "          [-4.4189e-02,  3.6719e-01, -2.0117e-01,  ..., -4.2188e-01,\n",
       "            2.0703e-01,  4.4434e-02],\n",
       "          ...,\n",
       "          [ 3.9978e-03, -3.9062e-01,  4.0039e-01,  ...,  7.0312e-01,\n",
       "           -7.3730e-02, -3.7842e-02],\n",
       "          [-1.7944e-02, -3.4570e-01, -2.7539e-01,  ...,  3.5889e-02,\n",
       "           -5.3711e-02, -2.0898e-01],\n",
       "          [-3.8867e-01, -6.1279e-02, -9.5215e-02,  ...,  2.4805e-01,\n",
       "            3.2227e-01,  3.8672e-01]],\n",
       "\n",
       "         [[-2.7313e-03,  3.6621e-03,  4.6997e-03,  ..., -2.9755e-03,\n",
       "            3.4180e-03,  4.1199e-03],\n",
       "          [-1.4221e-02,  1.8066e-02, -3.8281e-01,  ...,  7.2327e-03,\n",
       "           -1.8164e-01, -1.4832e-02],\n",
       "          [-2.4902e-01, -5.2979e-02, -2.2754e-01,  ...,  5.1514e-02,\n",
       "            8.5938e-02, -6.3965e-02],\n",
       "          ...,\n",
       "          [-1.5137e-01,  2.1240e-02, -2.8229e-03,  ...,  1.8188e-02,\n",
       "            3.9551e-02,  1.2598e-01],\n",
       "          [-1.0156e-01,  2.3926e-01,  3.4570e-01,  ...,  4.4556e-03,\n",
       "            1.1328e-01, -7.0312e-02],\n",
       "          [-1.2988e-01,  1.8066e-01,  2.8198e-02,  ..., -2.0020e-02,\n",
       "           -9.1309e-02,  7.9102e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-1.3000e-02,  4.8828e-02,  2.6978e-02,  ...,  4.2969e-02,\n",
       "            9.7168e-02,  1.9043e-01],\n",
       "          [ 1.9219e+00,  3.1836e-01, -7.8125e-01,  ...,  9.5312e-01,\n",
       "           -9.1016e-01, -6.3672e-01],\n",
       "          [ 2.5781e-01,  4.1406e-01, -1.6699e-01,  ..., -6.8750e-01,\n",
       "            5.6152e-02,  1.2109e+00],\n",
       "          ...,\n",
       "          [-2.0703e-01, -5.2734e-02,  5.6250e-01,  ..., -1.9922e-01,\n",
       "            9.2578e-01,  1.5918e-01],\n",
       "          [ 9.2578e-01, -9.6875e-01, -2.5000e-01,  ...,  4.5508e-01,\n",
       "           -5.8984e-01, -2.2969e+00],\n",
       "          [ 1.3438e+00, -4.0430e-01,  5.6641e-02,  ..., -4.6631e-02,\n",
       "           -1.8359e-01, -4.8828e-01]],\n",
       "\n",
       "         [[ 1.1673e-03, -1.4709e-02, -8.1177e-03,  ...,  7.6172e-01,\n",
       "            2.1289e-01,  4.8633e-01],\n",
       "          [ 2.7148e-01,  3.7109e-02,  3.6719e-01,  ..., -1.4844e+00,\n",
       "           -1.0234e+00, -4.7266e-01],\n",
       "          [ 7.8613e-02,  1.3574e-01,  6.4844e-01,  ..., -1.0781e+00,\n",
       "           -1.7422e+00, -8.5938e-01],\n",
       "          ...,\n",
       "          [ 7.6660e-02,  4.7852e-02, -5.7422e-01,  ..., -1.2344e+00,\n",
       "           -6.1328e-01,  3.4961e-01],\n",
       "          [ 1.5430e-01,  4.6094e-01, -1.4062e-01,  ..., -2.2812e+00,\n",
       "            1.0625e+00, -3.2043e-03],\n",
       "          [ 1.8262e-01,  1.1328e-01,  7.5391e-01,  ..., -2.4375e+00,\n",
       "            1.7109e+00,  4.0234e-01]],\n",
       "\n",
       "         [[-8.9111e-03, -4.7852e-02,  5.7373e-02,  ..., -5.0049e-02,\n",
       "           -1.5234e-01,  1.5039e-01],\n",
       "          [-2.5391e-01, -1.6602e-01,  3.3203e-01,  ...,  4.4727e-01,\n",
       "            4.9609e-01,  1.1953e+00],\n",
       "          [ 1.2085e-02, -2.6855e-03,  1.7188e-01,  ...,  6.9580e-03,\n",
       "            7.4609e-01,  4.6484e-01],\n",
       "          ...,\n",
       "          [-9.4238e-02,  3.9307e-02,  1.0010e-02,  ...,  1.6895e-01,\n",
       "            1.1426e-01,  1.7578e-01],\n",
       "          [ 1.5820e-01,  6.5625e-01,  2.8516e-01,  ...,  7.5000e-01,\n",
       "            6.0938e-01, -1.9922e+00],\n",
       "          [-1.5723e-01,  5.5908e-02,  2.1680e-01,  ..., -7.6562e-01,\n",
       "           -1.3770e-01, -7.7344e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.8433e-02, -1.1169e-02,  2.4048e-02,  ..., -1.5332e-01,\n",
       "           -2.2852e-01, -1.6211e-01],\n",
       "          [ 3.6719e-01,  2.0410e-01,  8.2422e-01,  ...,  1.5859e+00,\n",
       "            9.3359e-01,  4.0234e-01],\n",
       "          [ 1.6504e-01,  3.6133e-01, -2.6562e-01,  ...,  1.3672e+00,\n",
       "            1.8516e+00,  7.7734e-01],\n",
       "          ...,\n",
       "          [-3.5547e-01, -4.3945e-01,  3.2031e-01,  ...,  1.4609e+00,\n",
       "            2.3906e+00,  8.2520e-02],\n",
       "          [-2.2461e-01, -5.0000e-01, -2.8809e-02,  ...,  2.7930e-01,\n",
       "            1.4141e+00, -6.2109e-01],\n",
       "          [-8.7891e-02, -4.0430e-01,  1.0156e-01,  ...,  3.8086e-01,\n",
       "            2.1406e+00, -9.6094e-01]],\n",
       "\n",
       "         [[-1.1902e-02, -3.2471e-02,  2.3682e-02,  ..., -2.6172e-01,\n",
       "           -4.2578e-01, -3.9648e-01],\n",
       "          [ 6.3281e-01, -1.1328e-01, -3.9258e-01,  ...,  1.3359e+00,\n",
       "            1.3203e+00, -4.5703e-01],\n",
       "          [-1.7871e-01, -1.0986e-03, -2.9297e-01,  ...,  7.8516e-01,\n",
       "            1.1094e+00,  3.4961e-01],\n",
       "          ...,\n",
       "          [-5.5078e-01, -5.5469e-01,  5.3906e-01,  ..., -3.8672e-01,\n",
       "            6.1328e-01, -7.6172e-01],\n",
       "          [-1.6968e-02,  3.5547e-01,  6.5234e-01,  ...,  3.1250e-01,\n",
       "            1.1016e+00,  1.4375e+00],\n",
       "          [-4.9609e-01,  1.0391e+00,  7.3047e-01,  ...,  4.7656e-01,\n",
       "            2.0000e+00,  7.1484e-01]],\n",
       "\n",
       "         [[ 1.7395e-03, -2.7832e-02,  2.7618e-03,  ..., -1.9531e-01,\n",
       "           -4.5703e-01, -1.3086e-01],\n",
       "          [-1.1562e+00,  1.2793e-01, -8.9355e-02,  ...,  9.1797e-01,\n",
       "            1.8047e+00, -1.8828e+00],\n",
       "          [ 2.3926e-01,  2.3633e-01, -4.6387e-02,  ...,  9.2969e-01,\n",
       "            2.2812e+00, -8.9062e-01],\n",
       "          ...,\n",
       "          [-3.7500e-01, -8.7109e-01, -3.4570e-01,  ..., -3.6094e+00,\n",
       "            1.0312e+00, -1.5078e+00],\n",
       "          [-5.6641e-01, -5.1562e-01,  7.7344e-01,  ..., -6.9922e-01,\n",
       "            2.0938e+00, -9.0234e-01],\n",
       "          [-9.2969e-01, -5.6250e-01, -7.3242e-03,  ..., -3.1641e-01,\n",
       "            1.0078e+00,  1.8262e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 2.4261e-03, -2.0752e-03,  2.5482e-03,  ...,  5.4016e-03,\n",
       "           -1.1520e-03, -2.9602e-03],\n",
       "          [ 9.0820e-02, -2.1387e-01,  2.3730e-01,  ..., -3.0859e-01,\n",
       "            2.6367e-01, -7.9102e-02],\n",
       "          [ 1.2891e-01,  1.8066e-01,  3.5742e-01,  ..., -1.1426e-01,\n",
       "           -1.3965e-01,  1.3379e-01],\n",
       "          ...,\n",
       "          [ 1.4062e-01,  6.6895e-02, -3.5938e-01,  ...,  8.7402e-02,\n",
       "           -2.0020e-01,  1.1169e-02],\n",
       "          [-4.3213e-02, -9.8145e-02, -3.6133e-01,  ..., -2.2095e-02,\n",
       "           -8.9355e-02,  4.1016e-01],\n",
       "          [-2.8906e-01, -5.0964e-03, -1.6406e-01,  ..., -1.7676e-01,\n",
       "           -2.0410e-01,  2.5977e-01]],\n",
       "\n",
       "         [[-3.2959e-03, -8.6060e-03,  6.6223e-03,  ...,  7.4387e-04,\n",
       "            6.7444e-03,  5.7068e-03],\n",
       "          [-3.3398e-01,  4.7461e-01, -3.1445e-01,  ..., -1.1279e-01,\n",
       "            2.2363e-01,  5.1172e-01],\n",
       "          [-4.2969e-01,  3.7305e-01, -1.2695e-01,  ...,  2.3438e-01,\n",
       "           -6.9824e-02,  9.9609e-02],\n",
       "          ...,\n",
       "          [-1.3477e-01, -9.1797e-02,  6.7383e-02,  ...,  3.2422e-01,\n",
       "            3.8086e-01, -3.4961e-01],\n",
       "          [-6.8848e-02, -3.4180e-01,  3.3691e-02,  ..., -4.8633e-01,\n",
       "            1.4746e-01, -1.9531e-01],\n",
       "          [ 2.0996e-01, -1.0254e-01, -2.9663e-02,  ..., -3.6133e-01,\n",
       "            1.4746e-01,  3.0469e-01]],\n",
       "\n",
       "         [[ 1.4114e-03,  5.0659e-03,  1.0757e-03,  ..., -1.9836e-03,\n",
       "           -3.2349e-03,  1.7822e-02],\n",
       "          [ 1.3672e-01,  2.9688e-01,  2.7539e-01,  ..., -3.1836e-01,\n",
       "           -9.4238e-02, -8.6426e-02],\n",
       "          [-1.0645e-01, -1.8848e-01,  6.2891e-01,  ...,  7.4609e-01,\n",
       "           -2.3340e-01, -1.1292e-02],\n",
       "          ...,\n",
       "          [-2.5977e-01, -5.7068e-03, -3.0859e-01,  ...,  3.9258e-01,\n",
       "            4.7363e-02, -2.7734e-01],\n",
       "          [-9.4727e-02,  2.7734e-01, -1.2402e-01,  ..., -5.7031e-01,\n",
       "            2.1875e-01, -9.4141e-01],\n",
       "          [ 2.4512e-01, -2.8125e-01, -9.0332e-02,  ..., -6.6406e-02,\n",
       "            2.1777e-01, -7.0312e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 4.0588e-03, -2.1210e-03, -1.0452e-03,  ...,  3.2501e-03,\n",
       "            3.0518e-03, -3.0670e-03],\n",
       "          [-3.9844e-01,  6.7188e-01, -2.3315e-02,  ...,  4.8828e-01,\n",
       "            2.2852e-01,  1.1230e-01],\n",
       "          [-1.4062e-01,  4.1260e-02,  1.2158e-01,  ...,  1.0547e-01,\n",
       "           -9.4727e-02,  1.4453e-01],\n",
       "          ...,\n",
       "          [-5.1562e-01,  2.4023e-01,  5.5859e-01,  ...,  3.7842e-02,\n",
       "            1.0156e-01,  7.1289e-02],\n",
       "          [-4.3359e-01,  2.5195e-01, -8.3008e-02,  ..., -2.4707e-01,\n",
       "            1.5234e-01,  3.8672e-01],\n",
       "          [-4.8242e-01,  1.7578e-01, -1.7773e-01,  ..., -1.5625e-01,\n",
       "            6.2256e-03,  2.3340e-01]],\n",
       "\n",
       "         [[-3.3722e-03,  1.4801e-03, -6.0730e-03,  ..., -2.1606e-02,\n",
       "            7.7438e-04, -6.7520e-04],\n",
       "          [-1.8082e-03,  7.0312e-02,  1.6895e-01,  ...,  2.0996e-01,\n",
       "           -4.6631e-02, -1.4453e-01],\n",
       "          [ 1.1670e-01, -7.7148e-02, -4.8584e-02,  ...,  2.8516e-01,\n",
       "           -1.0791e-01,  1.5625e-01],\n",
       "          ...,\n",
       "          [ 1.5430e-01,  1.9531e-01, -2.0117e-01,  ..., -1.2256e-01,\n",
       "           -2.8711e-01,  1.2012e-01],\n",
       "          [ 2.7930e-01,  1.3574e-01, -3.9551e-02,  ..., -6.7871e-02,\n",
       "           -3.0859e-01,  3.4790e-03],\n",
       "          [ 1.4453e-01,  3.0273e-01,  3.3008e-01,  ...,  2.6562e-01,\n",
       "           -2.1484e-01,  2.0142e-02]],\n",
       "\n",
       "         [[ 7.0496e-03, -1.9897e-02, -5.4932e-03,  ...,  2.5330e-03,\n",
       "           -5.9509e-03, -5.0659e-03],\n",
       "          [-3.3984e-01, -1.3379e-01,  9.0820e-02,  ...,  8.9844e-02,\n",
       "            6.3965e-02, -2.6172e-01],\n",
       "          [-3.4570e-01,  3.9844e-01, -1.1377e-01,  ..., -4.1016e-02,\n",
       "            2.0215e-01, -1.8262e-01],\n",
       "          ...,\n",
       "          [-7.1289e-02,  2.5000e-01,  6.6757e-04,  ...,  3.5352e-01,\n",
       "            1.6968e-02, -9.6680e-02],\n",
       "          [ 1.4746e-01, -3.5156e-02, -2.3071e-02,  ..., -9.4727e-02,\n",
       "           -5.1514e-02,  5.8838e-02],\n",
       "          [ 1.7285e-01, -1.6992e-01,  1.3367e-02,  ...,  1.9727e-01,\n",
       "           -1.1816e-01, -1.5564e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-4.6875e-02, -6.3324e-04, -2.8076e-03,  ...,  1.9043e-02,\n",
       "           -2.1777e-01,  1.0620e-02],\n",
       "          [-7.2266e-02, -2.1484e-02,  2.3730e-01,  ...,  6.7188e-01,\n",
       "            1.2734e+00, -2.5586e-01],\n",
       "          [-3.7500e-01, -5.6641e-01,  4.4141e-01,  ..., -2.2949e-01,\n",
       "            2.2969e+00, -3.3984e-01],\n",
       "          ...,\n",
       "          [ 4.4531e-01,  3.4180e-01, -6.6406e-02,  ..., -2.4316e-01,\n",
       "            1.5547e+00,  1.0625e+00],\n",
       "          [ 2.1094e-01,  1.0469e+00, -6.3281e-01,  ..., -1.7812e+00,\n",
       "            8.4375e-01,  2.1387e-01],\n",
       "          [-2.8516e-01,  2.1875e-01, -7.2656e-01,  ..., -3.8672e-01,\n",
       "            6.1328e-01,  9.4141e-01]],\n",
       "\n",
       "         [[-8.1177e-03,  3.0273e-02, -2.4658e-02,  ...,  4.2114e-03,\n",
       "           -1.3916e-02, -1.4893e-02],\n",
       "          [ 6.6406e-02, -5.5469e-01,  1.6602e-02,  ...,  1.3281e+00,\n",
       "           -1.1328e+00,  7.4609e-01],\n",
       "          [ 1.0693e-01, -1.0059e-01, -1.4062e-01,  ..., -5.0781e-01,\n",
       "           -8.0859e-01,  1.2734e+00],\n",
       "          ...,\n",
       "          [-2.4707e-01,  5.3125e-01, -7.6660e-02,  ...,  4.4727e-01,\n",
       "            1.2500e+00,  1.9297e+00],\n",
       "          [ 2.0020e-01,  3.0664e-01,  2.5977e-01,  ...,  6.5234e-01,\n",
       "            7.7344e-01,  1.0547e+00],\n",
       "          [ 1.6406e-01,  2.0703e-01,  9.9121e-02,  ...,  5.9375e-01,\n",
       "            9.2969e-01,  1.0781e+00]],\n",
       "\n",
       "         [[ 5.3406e-03,  1.8311e-02, -5.4932e-03,  ..., -1.8750e-01,\n",
       "            3.6914e-01, -8.5156e-01],\n",
       "          [-6.2500e-01, -1.3281e-01,  5.3906e-01,  ...,  5.8594e-01,\n",
       "           -4.8828e-01,  2.1406e+00],\n",
       "          [ 6.0547e-01, -1.5625e+00, -4.8828e-01,  ...,  1.6094e+00,\n",
       "           -4.1406e-01,  2.8906e+00],\n",
       "          ...,\n",
       "          [-9.5703e-01,  1.2266e+00, -1.4453e-01,  ...,  1.7422e+00,\n",
       "           -1.9766e+00,  3.1875e+00],\n",
       "          [-1.0156e+00,  6.2500e-01,  4.1797e-01,  ...,  4.2773e-01,\n",
       "           -8.4375e-01,  3.2188e+00],\n",
       "          [ 1.2031e+00, -9.8633e-02,  8.5547e-01,  ...,  9.6094e-01,\n",
       "           -9.8828e-01,  2.7500e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-2.1973e-02, -1.4709e-02,  1.4404e-02,  ..., -4.3359e-01,\n",
       "           -4.1016e-01,  7.4219e-02],\n",
       "          [-2.7930e-01,  2.2266e-01,  9.9121e-02,  ..., -7.8906e-01,\n",
       "            1.6953e+00,  2.0801e-01],\n",
       "          [-3.4668e-02,  9.2773e-02, -4.1992e-01,  ...,  1.2812e+00,\n",
       "            2.3438e+00,  4.5898e-02],\n",
       "          ...,\n",
       "          [ 2.5000e-01, -4.8584e-02, -4.3555e-01,  ...,  1.8906e+00,\n",
       "            1.2266e+00, -3.8086e-01],\n",
       "          [ 2.3242e-01, -5.4297e-01, -8.0078e-01,  ...,  1.0234e+00,\n",
       "            1.4922e+00, -1.2812e+00],\n",
       "          [-2.4707e-01, -1.0889e-01, -1.1426e-01,  ...,  2.9688e+00,\n",
       "            1.1172e+00, -1.2734e+00]],\n",
       "\n",
       "         [[-3.7842e-02,  2.0020e-02,  1.9150e-03,  ..., -2.1680e-01,\n",
       "           -1.0840e-01, -1.2012e-01],\n",
       "          [ 3.1250e-02, -8.0469e-01,  3.7500e-01,  ..., -6.4453e-02,\n",
       "           -1.8750e-01,  9.2578e-01],\n",
       "          [ 5.7031e-01,  3.3594e-01,  1.7188e-01,  ...,  2.1094e+00,\n",
       "           -3.4668e-02,  8.2422e-01],\n",
       "          ...,\n",
       "          [-3.8477e-01, -2.5586e-01,  3.5742e-01,  ...,  1.7656e+00,\n",
       "            1.2812e+00, -4.0625e-01],\n",
       "          [-1.7285e-01, -5.8594e-03, -5.1172e-01,  ...,  1.5547e+00,\n",
       "            1.5859e+00, -7.3438e-01],\n",
       "          [ 5.8594e-02,  2.4902e-02,  2.6758e-01,  ...,  1.1016e+00,\n",
       "            9.8438e-01, -1.2734e+00]],\n",
       "\n",
       "         [[ 3.2227e-02, -4.3030e-03,  3.6812e-04,  ...,  8.8672e-01,\n",
       "           -1.1768e-01,  7.4609e-01],\n",
       "          [ 3.7109e-01,  1.3379e-01, -1.3672e-01,  ..., -1.1172e+00,\n",
       "            2.2031e+00, -2.6250e+00],\n",
       "          [ 3.8672e-01,  5.6641e-02,  5.2979e-02,  ..., -1.9297e+00,\n",
       "            1.8125e+00, -2.6406e+00],\n",
       "          ...,\n",
       "          [-5.8594e-01, -7.0312e-02, -3.9258e-01,  ..., -2.0469e+00,\n",
       "            5.3516e-01, -2.6719e+00],\n",
       "          [-5.0391e-01,  1.6094e+00, -4.5312e-01,  ..., -3.1406e+00,\n",
       "           -1.2266e+00, -2.5625e+00],\n",
       "          [ 2.2461e-01,  6.3672e-01, -3.6133e-01,  ..., -2.4375e+00,\n",
       "            9.0234e-01, -3.2812e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-5.0964e-03, -2.4658e-02, -3.8910e-03,  ...,  9.3994e-03,\n",
       "            1.3733e-02, -1.1841e-02],\n",
       "          [ 2.1191e-01,  5.8203e-01,  2.9688e-01,  ...,  2.2070e-01,\n",
       "           -2.3560e-02,  5.3125e-01],\n",
       "          [-1.7871e-01, -1.0864e-02,  1.9043e-01,  ..., -2.2583e-02,\n",
       "           -1.0059e-01, -1.8262e-01],\n",
       "          ...,\n",
       "          [-3.2031e-01,  3.2812e-01, -5.5859e-01,  ..., -1.9531e-01,\n",
       "           -3.6523e-01, -3.1641e-01],\n",
       "          [-6.9141e-01,  1.4258e-01, -5.6641e-01,  ...,  8.9844e-02,\n",
       "            2.6172e-01, -7.6562e-01],\n",
       "          [-8.6328e-01,  1.3770e-01, -5.1562e-01,  ...,  2.2266e-01,\n",
       "            5.3516e-01, -1.0625e+00]],\n",
       "\n",
       "         [[-2.1362e-03,  1.2695e-02,  1.6327e-03,  ..., -1.0620e-02,\n",
       "           -8.4839e-03, -3.6621e-03],\n",
       "          [ 4.8047e-01, -2.4707e-01,  2.6172e-01,  ...,  3.3789e-01,\n",
       "            5.8203e-01, -6.7383e-02],\n",
       "          [ 1.7676e-01,  1.0938e-01,  1.7383e-01,  ..., -4.6143e-02,\n",
       "           -5.2979e-02, -2.4536e-02],\n",
       "          ...,\n",
       "          [-6.5234e-01,  1.2085e-02,  1.2500e-01,  ...,  2.7344e-01,\n",
       "           -1.9824e-01,  2.5586e-01],\n",
       "          [ 1.7578e-01, -2.2461e-01, -1.4355e-01,  ...,  2.1973e-01,\n",
       "           -3.0859e-01,  7.1289e-02],\n",
       "          [ 4.6094e-01, -5.8203e-01, -4.4531e-01,  ..., -4.3359e-01,\n",
       "            2.0410e-01,  1.0400e-01]],\n",
       "\n",
       "         [[ 2.8992e-03, -1.3809e-03, -1.1780e-02,  ...,  1.3809e-03,\n",
       "            1.9897e-02, -5.8289e-03],\n",
       "          [ 4.4727e-01, -2.9297e-01, -7.0312e-01,  ..., -7.8613e-02,\n",
       "            2.7148e-01,  1.0303e-01],\n",
       "          [-4.7070e-01,  2.3828e-01, -3.9258e-01,  ..., -3.8672e-01,\n",
       "           -2.7466e-02,  5.0049e-03],\n",
       "          ...,\n",
       "          [ 7.2266e-02,  5.8203e-01,  1.7822e-02,  ..., -8.5938e-01,\n",
       "           -4.1211e-01, -3.3789e-01],\n",
       "          [-1.4941e-01,  2.0312e-01, -3.6328e-01,  ..., -4.4922e-01,\n",
       "            7.3242e-02, -5.3125e-01],\n",
       "          [ 2.0801e-01,  4.4922e-01,  2.0215e-01,  ..., -5.8594e-01,\n",
       "            2.8516e-01, -1.7871e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-9.3994e-03, -1.1673e-03,  7.5378e-03,  ..., -2.5330e-03,\n",
       "           -7.4463e-03, -2.7847e-04],\n",
       "          [ 1.0498e-01, -5.0049e-02,  4.4141e-01,  ..., -1.9824e-01,\n",
       "           -1.0693e-01,  4.0039e-01],\n",
       "          [-3.5889e-02, -2.5195e-01,  3.1250e-01,  ..., -8.6426e-02,\n",
       "            5.7812e-01,  3.2617e-01],\n",
       "          ...,\n",
       "          [ 1.9336e-01, -4.1797e-01,  1.9824e-01,  ...,  2.5781e-01,\n",
       "            4.1992e-01, -6.2891e-01],\n",
       "          [ 5.8594e-01,  1.8164e-01, -2.0020e-01,  ..., -1.5723e-01,\n",
       "            1.6309e-01, -1.0742e-01],\n",
       "          [ 5.4297e-01, -1.7871e-01, -1.4355e-01,  ...,  2.5586e-01,\n",
       "            3.1433e-03, -2.2656e-01]],\n",
       "\n",
       "         [[-2.0447e-03,  3.0975e-03, -2.0905e-03,  ..., -3.9978e-03,\n",
       "           -5.3787e-04,  6.1340e-03],\n",
       "          [-3.5742e-01,  4.9072e-02,  2.6172e-01,  ...,  1.3867e-01,\n",
       "            1.7188e-01, -1.6992e-01],\n",
       "          [-4.1992e-01,  1.4258e-01,  6.0156e-01,  ...,  1.9727e-01,\n",
       "            1.1865e-01,  1.3574e-01],\n",
       "          ...,\n",
       "          [ 1.8750e-01,  3.7354e-02,  4.7266e-01,  ...,  3.5547e-01,\n",
       "           -4.3164e-01, -6.2256e-02],\n",
       "          [ 1.0352e-01, -3.6719e-01,  2.2363e-01,  ...,  2.4902e-01,\n",
       "           -1.6992e-01, -2.6758e-01],\n",
       "          [ 1.0376e-02, -4.0625e-01,  3.8281e-01,  ...,  2.4512e-01,\n",
       "            2.8442e-02,  2.3145e-01]],\n",
       "\n",
       "         [[-4.5166e-03, -8.8501e-03,  1.1780e-02,  ...,  4.6997e-03,\n",
       "           -1.4954e-03,  1.5030e-03],\n",
       "          [ 2.5000e-01,  1.2061e-01, -8.5449e-02,  ..., -1.2598e-01,\n",
       "            2.7734e-01, -3.3203e-01],\n",
       "          [ 1.1865e-01,  1.3086e-01, -1.2256e-01,  ...,  1.1035e-01,\n",
       "           -1.2451e-02, -9.4727e-02],\n",
       "          ...,\n",
       "          [-1.5332e-01,  1.6211e-01,  1.7090e-01,  ...,  1.8262e-01,\n",
       "           -2.0898e-01,  8.1055e-02],\n",
       "          [ 9.4727e-02,  1.0059e-01,  2.2461e-01,  ...,  4.7070e-01,\n",
       "           -3.1641e-01,  1.2158e-01],\n",
       "          [-1.0059e-01,  2.5781e-01, -3.3398e-01,  ...,  2.5781e-01,\n",
       "           -2.6172e-01,  7.0801e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-9.2773e-03, -1.3809e-03,  1.3580e-03,  ..., -1.2256e-01,\n",
       "           -2.7344e-01,  5.3125e-01],\n",
       "          [-2.4023e-01, -7.0312e-01,  2.4219e-01,  ...,  8.1641e-01,\n",
       "            8.5938e-01, -1.6641e+00],\n",
       "          [-6.5625e-01, -7.8516e-01,  1.0156e+00,  ...,  2.0781e+00,\n",
       "           -1.9434e-01, -1.9297e+00],\n",
       "          ...,\n",
       "          [-9.3750e-02,  5.1172e-01, -1.7676e-01,  ...,  5.4688e-01,\n",
       "           -9.6094e-01, -2.3125e+00],\n",
       "          [-1.2109e-01,  2.9297e-03, -1.6602e-02,  ..., -4.6094e-01,\n",
       "            1.0859e+00, -2.1406e+00],\n",
       "          [ 5.3125e-01, -5.2344e-01,  6.4844e-01,  ...,  9.4141e-01,\n",
       "           -4.1211e-01, -2.0781e+00]],\n",
       "\n",
       "         [[-6.3171e-03,  1.9409e-02, -9.7656e-03,  ...,  1.2656e+00,\n",
       "           -2.7930e-01, -3.4961e-01],\n",
       "          [ 2.4023e-01,  6.4062e-01,  4.7266e-01,  ..., -2.2500e+00,\n",
       "            1.1484e+00,  6.5234e-01],\n",
       "          [ 2.4512e-01,  1.4355e-01,  5.4932e-02,  ..., -2.0469e+00,\n",
       "            1.2188e+00,  8.5938e-01],\n",
       "          ...,\n",
       "          [-1.6797e-01, -4.1211e-01, -5.9375e-01,  ..., -3.6250e+00,\n",
       "            6.7969e-01, -7.2266e-01],\n",
       "          [-9.1406e-01, -7.8906e-01, -7.8516e-01,  ..., -4.7188e+00,\n",
       "           -2.0703e-01, -2.6094e+00],\n",
       "          [-4.6680e-01,  6.3477e-02, -4.4727e-01,  ..., -4.2188e+00,\n",
       "            9.2773e-02, -2.5312e+00]],\n",
       "\n",
       "         [[ 3.4912e-02, -4.1199e-03,  4.1992e-02,  ...,  6.0059e-02,\n",
       "           -5.9204e-03, -5.5908e-02],\n",
       "          [ 3.1055e-01,  2.1191e-01, -4.9609e-01,  ..., -3.1055e-01,\n",
       "            1.2188e+00,  1.0625e+00],\n",
       "          [ 3.4961e-01,  2.6367e-01, -7.4707e-02,  ...,  8.4961e-02,\n",
       "            1.1328e+00,  9.1016e-01],\n",
       "          ...,\n",
       "          [ 3.2227e-01,  1.8262e-01,  2.5000e-01,  ..., -5.1953e-01,\n",
       "            1.4766e+00,  5.4297e-01],\n",
       "          [-1.1914e-01,  7.0703e-01,  4.1992e-01,  ..., -5.9766e-01,\n",
       "            5.7031e-01,  6.2500e-01],\n",
       "          [-5.1562e-01, -9.7656e-04,  8.0078e-02,  ...,  1.0234e+00,\n",
       "            2.4062e+00, -3.9978e-03]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.5747e-02,  4.2969e-02,  1.1063e-03,  ...,  1.5723e-01,\n",
       "            1.4160e-01, -1.0840e-01],\n",
       "          [ 3.2031e-01,  2.4023e-01,  1.7656e+00,  ...,  1.6484e+00,\n",
       "           -1.2109e+00,  1.3047e+00],\n",
       "          [ 3.6133e-01,  5.8594e-03,  5.9375e-01,  ...,  2.0625e+00,\n",
       "           -7.3828e-01, -7.1484e-01],\n",
       "          ...,\n",
       "          [-1.6602e-02,  3.6133e-01, -1.4648e-02,  ..., -1.5430e-01,\n",
       "           -3.9062e-01, -8.5547e-01],\n",
       "          [-2.5391e-02, -1.0078e+00, -8.6328e-01,  ...,  4.3701e-02,\n",
       "            3.1055e-01, -1.3047e+00],\n",
       "          [ 1.5137e-01,  8.7891e-03, -1.6602e-01,  ...,  3.2422e-01,\n",
       "           -2.2949e-01, -4.6875e-01]],\n",
       "\n",
       "         [[-3.6865e-02, -2.2949e-02,  1.0010e-02,  ..., -2.1289e-01,\n",
       "            3.1055e-01,  5.8203e-01],\n",
       "          [-1.7090e-01, -9.3262e-02,  1.9434e-01,  ..., -8.0566e-02,\n",
       "           -1.0400e-01, -2.0156e+00],\n",
       "          [ 3.3203e-02,  3.2227e-01,  2.4219e-01,  ...,  1.6094e+00,\n",
       "           -8.9453e-01, -4.4062e+00],\n",
       "          ...,\n",
       "          [-1.9922e-01,  7.1289e-02, -2.9297e-01,  ...,  1.9609e+00,\n",
       "           -1.1016e+00, -2.3438e+00],\n",
       "          [-7.9102e-02,  1.3965e-01, -3.7842e-02,  ...,  1.8594e+00,\n",
       "           -7.2656e-01, -2.6406e+00],\n",
       "          [ 4.8828e-04,  4.0820e-01,  3.4570e-01,  ...,  4.7852e-01,\n",
       "            1.9609e+00, -6.2891e-01]],\n",
       "\n",
       "         [[ 1.6846e-02,  1.1902e-02,  3.7842e-02,  ..., -2.3633e-01,\n",
       "           -1.2207e-01,  8.3496e-02],\n",
       "          [ 3.4668e-02,  1.5430e-01, -5.2734e-02,  ...,  7.1484e-01,\n",
       "           -1.1016e+00,  8.8379e-02],\n",
       "          [-3.1836e-01, -4.0234e-01, -5.6641e-01,  ...,  1.7734e+00,\n",
       "            1.7285e-01, -1.8262e-01],\n",
       "          ...,\n",
       "          [-7.8613e-02, -2.8076e-02,  3.0859e-01,  ...,  1.3594e+00,\n",
       "            2.0312e+00, -1.6094e+00],\n",
       "          [ 1.2305e-01, -5.8594e-01, -4.1016e-01,  ...,  2.9102e-01,\n",
       "            1.8164e-01, -6.1328e-01],\n",
       "          [-5.7031e-01, -1.7969e-01, -7.2266e-01,  ..., -2.1484e-01,\n",
       "            9.8047e-01, -8.6328e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-9.2163e-03, -2.2583e-02, -3.1891e-03,  ..., -1.5717e-03,\n",
       "           -4.9133e-03, -8.3618e-03],\n",
       "          [-7.4219e-01,  6.7969e-01, -1.4062e-01,  ..., -7.8516e-01,\n",
       "            6.7188e-01,  3.5938e-01],\n",
       "          [ 1.7285e-01,  3.3594e-01,  4.0820e-01,  ...,  1.5039e-01,\n",
       "            1.4551e-01,  2.6367e-01],\n",
       "          ...,\n",
       "          [-9.1309e-02,  4.3555e-01,  6.9824e-02,  ...,  1.0400e-01,\n",
       "            3.3188e-04, -4.1602e-01],\n",
       "          [-2.9883e-01, -2.4121e-01, -1.7773e-01,  ...,  1.0107e-01,\n",
       "            5.1172e-01,  2.9688e-01],\n",
       "          [ 2.5391e-01, -3.8086e-01,  2.8906e-01,  ..., -4.5898e-01,\n",
       "           -1.1426e-01,  3.8086e-01]],\n",
       "\n",
       "         [[ 1.6113e-02, -1.3351e-03,  2.8839e-03,  ...,  9.3384e-03,\n",
       "            2.6245e-03, -2.3651e-03],\n",
       "          [-5.1172e-01,  3.8330e-02,  1.4404e-02,  ..., -2.0801e-01,\n",
       "            1.4746e-01, -2.6172e-01],\n",
       "          [-1.4844e-01,  4.5898e-01, -3.6719e-01,  ..., -3.8867e-01,\n",
       "           -1.0645e-01, -2.3560e-02],\n",
       "          ...,\n",
       "          [ 1.9434e-01,  3.4570e-01,  7.9346e-03,  ..., -1.6113e-01,\n",
       "            1.5430e-01,  6.0938e-01],\n",
       "          [ 1.7480e-01,  8.2422e-01, -1.2012e-01,  ..., -4.1602e-01,\n",
       "            2.0801e-01,  2.3438e-01],\n",
       "          [-3.0078e-01,  1.0859e+00, -3.4570e-01,  ..., -6.0547e-01,\n",
       "           -4.9072e-02, -1.3184e-01]],\n",
       "\n",
       "         [[-5.6763e-03, -2.0752e-02,  6.7139e-03,  ...,  9.6191e-02,\n",
       "           -4.3106e-04, -6.2256e-03],\n",
       "          [ 3.7109e-01, -3.8086e-01,  2.1875e-01,  ..., -1.2188e+00,\n",
       "            2.3828e-01,  1.7578e-02],\n",
       "          [-2.7539e-01,  2.2754e-01, -4.9609e-01,  ..., -8.8281e-01,\n",
       "            5.8594e-01, -6.7578e-01],\n",
       "          ...,\n",
       "          [-1.8652e-01, -9.0332e-02, -1.5332e-01,  ..., -6.9922e-01,\n",
       "            6.7871e-02,  2.1973e-03],\n",
       "          [ 3.6621e-02, -5.9570e-02, -2.2949e-01,  ..., -1.9141e-01,\n",
       "           -3.5938e-01, -4.7461e-01],\n",
       "          [-5.0781e-01, -5.3906e-01, -3.5156e-01,  ..., -5.0391e-01,\n",
       "           -7.0801e-02, -2.5269e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-5.6458e-03, -1.0620e-02,  2.9449e-03,  ..., -1.9989e-03,\n",
       "            3.6163e-03,  2.2888e-03],\n",
       "          [ 6.6895e-02, -3.9258e-01,  2.8516e-01,  ...,  2.8125e-01,\n",
       "           -6.7969e-01,  5.5420e-02],\n",
       "          [ 3.4375e-01, -1.7090e-01,  4.6875e-01,  ...,  4.5312e-01,\n",
       "           -6.2109e-01, -7.5781e-01],\n",
       "          ...,\n",
       "          [ 2.8125e-01,  6.0547e-01,  2.6367e-01,  ...,  1.7285e-01,\n",
       "           -5.5469e-01, -2.1191e-01],\n",
       "          [ 1.4453e-01,  3.5352e-01,  2.0142e-02,  ...,  4.5898e-01,\n",
       "           -4.8633e-01,  3.6133e-02],\n",
       "          [ 4.4336e-01,  8.0566e-02,  7.0312e-02,  ...,  2.2949e-01,\n",
       "           -3.2031e-01,  2.5781e-01]],\n",
       "\n",
       "         [[-1.1658e-02,  5.7678e-03, -1.0300e-04,  ..., -9.8877e-03,\n",
       "           -1.0742e-02,  1.4191e-03],\n",
       "          [ 1.7578e-01, -4.3701e-02,  1.4844e-01,  ...,  2.7008e-03,\n",
       "            4.3164e-01,  1.3086e-01],\n",
       "          [ 7.4219e-02, -2.1875e-01,  3.1494e-02,  ...,  4.2383e-01,\n",
       "            2.5482e-03,  1.1816e-01],\n",
       "          ...,\n",
       "          [ 1.2354e-01,  7.1777e-02,  1.0693e-01,  ...,  7.2656e-01,\n",
       "            2.4121e-01, -2.4219e-01],\n",
       "          [ 8.4839e-03,  2.3438e-01,  2.7344e-01,  ...,  6.9531e-01,\n",
       "            1.1475e-01, -2.5000e-01],\n",
       "          [ 1.6602e-01, -1.0986e-01,  3.2812e-01,  ...,  1.3086e-01,\n",
       "            9.6191e-02, -3.6523e-01]],\n",
       "\n",
       "         [[ 1.9531e-03, -3.3417e-03, -5.4016e-03,  ...,  3.5553e-03,\n",
       "           -7.2327e-03, -6.0654e-04],\n",
       "          [ 2.7734e-01, -9.6436e-03, -3.7109e-01,  ...,  4.4141e-01,\n",
       "           -4.4922e-01, -3.1641e-01],\n",
       "          [-2.3340e-01,  2.9297e-01, -4.9072e-02,  ..., -3.5938e-01,\n",
       "           -2.6758e-01, -4.7461e-01],\n",
       "          ...,\n",
       "          [ 6.5234e-01, -7.0312e-02, -1.6403e-03,  ..., -7.3730e-02,\n",
       "            2.7930e-01, -2.3828e-01],\n",
       "          [ 5.4297e-01, -7.2266e-02,  7.4707e-02,  ..., -2.4805e-01,\n",
       "           -1.6797e-01, -9.9121e-02],\n",
       "          [ 4.3359e-01,  2.2754e-01,  6.4453e-01,  ...,  4.1260e-02,\n",
       "           -6.9922e-01,  7.6660e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-3.1494e-02,  7.8735e-03,  3.3203e-02,  ...,  8.3496e-02,\n",
       "            2.0996e-01,  4.3457e-02],\n",
       "          [ 4.0234e-01, -2.5781e-01, -5.4688e-02,  ..., -3.3984e-01,\n",
       "            1.9922e+00, -9.8438e-01],\n",
       "          [ 1.8848e-01, -2.8711e-01, -3.1250e-01,  ...,  5.9766e-01,\n",
       "            9.9609e-01, -2.8750e+00],\n",
       "          ...,\n",
       "          [-2.7539e-01, -7.9102e-02,  4.2969e-01,  ..., -3.4766e-01,\n",
       "            6.2891e-01,  1.8457e-01],\n",
       "          [ 1.8799e-02,  2.0312e-01,  4.1260e-02,  ...,  1.2451e-01,\n",
       "            8.8281e-01,  3.7891e-01],\n",
       "          [-9.9609e-02,  2.3047e-01,  1.4355e-01,  ...,  5.1953e-01,\n",
       "           -2.8320e-01,  2.4609e-01]],\n",
       "\n",
       "         [[-1.3428e-02,  5.6458e-03,  5.6152e-03,  ..., -1.6309e-01,\n",
       "           -1.2360e-03,  1.5381e-02],\n",
       "          [-1.3477e-01,  2.2168e-01,  7.1484e-01,  ..., -1.0303e-01,\n",
       "           -1.1562e+00,  5.1172e-01],\n",
       "          [ 6.2500e-01, -5.3516e-01,  6.5625e-01,  ..., -1.0859e+00,\n",
       "           -9.4141e-01, -3.8867e-01],\n",
       "          ...,\n",
       "          [-8.2812e-01, -6.2012e-02, -1.1523e-01,  ...,  9.3750e-01,\n",
       "           -1.1182e-01, -1.1641e+00],\n",
       "          [ 1.0547e-01, -1.9336e-01, -6.7578e-01,  ...,  8.0859e-01,\n",
       "           -1.4062e+00, -1.4609e+00],\n",
       "          [ 9.2578e-01, -2.1484e-01,  2.5195e-01,  ...,  4.2383e-01,\n",
       "           -6.8359e-01, -1.3516e+00]],\n",
       "\n",
       "         [[ 4.3213e-02, -3.7109e-02, -1.7578e-02,  ..., -1.0254e-01,\n",
       "           -7.1289e-02, -7.9590e-02],\n",
       "          [-2.1973e-03,  1.9629e-01, -8.7402e-02,  ...,  2.3750e+00,\n",
       "            2.3535e-01, -1.9219e+00],\n",
       "          [-5.6250e-01, -8.1543e-02,  7.8516e-01,  ...,  1.1875e+00,\n",
       "           -4.1016e-01,  9.6680e-02],\n",
       "          ...,\n",
       "          [ 5.5859e-01, -2.4902e-01, -2.1875e-01,  ..., -8.9453e-01,\n",
       "            9.1406e-01, -1.4453e+00],\n",
       "          [ 7.7734e-01, -1.3828e+00, -7.4609e-01,  ..., -5.8203e-01,\n",
       "            5.3516e-01, -2.9883e-01],\n",
       "          [-1.1523e-01, -1.7578e-02,  1.1719e-02,  ..., -3.4180e-01,\n",
       "            1.6094e+00, -5.8594e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-9.2773e-03,  2.5024e-02, -3.7689e-03,  ..., -1.7188e-01,\n",
       "           -1.2354e-01, -8.4229e-03],\n",
       "          [-2.1289e-01, -9.2773e-02, -1.5527e-01,  ...,  6.6406e-01,\n",
       "            1.8203e+00, -2.1094e-01],\n",
       "          [-1.6992e-01,  1.5039e-01,  4.0820e-01,  ..., -9.0942e-03,\n",
       "            1.5078e+00,  1.2598e-01],\n",
       "          ...,\n",
       "          [ 1.6113e-01, -4.5898e-02,  1.2031e+00,  ..., -9.4531e-01,\n",
       "            1.4453e+00, -6.7188e-01],\n",
       "          [-1.7773e-01,  1.2354e-01,  2.5781e-01,  ..., -4.7656e-01,\n",
       "            1.2656e+00,  9.7656e-02],\n",
       "          [ 4.5508e-01, -5.8203e-01, -2.8711e-01,  ..., -1.8047e+00,\n",
       "            4.0430e-01,  1.2109e-01]],\n",
       "\n",
       "         [[-4.8218e-03,  2.9755e-03,  5.4443e-02,  ..., -7.0312e-01,\n",
       "           -8.6060e-03,  4.9805e-02],\n",
       "          [-2.9785e-02, -9.7656e-03, -3.8867e-01,  ...,  3.3594e+00,\n",
       "           -4.0234e-01,  1.5332e-01],\n",
       "          [ 4.5703e-01,  5.4688e-01, -4.2188e-01,  ...,  3.3750e+00,\n",
       "            2.3730e-01, -1.7383e-01],\n",
       "          ...,\n",
       "          [ 4.1406e-01, -3.3789e-01, -3.7109e-01,  ...,  4.1562e+00,\n",
       "            1.0596e-01, -1.7031e+00],\n",
       "          [ 8.4766e-01, -4.2383e-01, -5.2344e-01,  ...,  4.9062e+00,\n",
       "           -7.1094e-01, -1.9844e+00],\n",
       "          [ 6.4062e-01, -5.7031e-01, -1.4062e-01,  ...,  3.3438e+00,\n",
       "           -1.8203e+00, -1.4160e-01]],\n",
       "\n",
       "         [[-7.3547e-03, -3.0273e-02,  1.8555e-02,  ..., -1.6562e+00,\n",
       "            9.0820e-02,  3.6133e-02],\n",
       "          [ 2.2168e-01,  9.2773e-02,  4.3359e-01,  ...,  3.8750e+00,\n",
       "            1.7266e+00,  8.3203e-01],\n",
       "          [ 2.4414e-02,  1.5625e-01,  2.7734e-01,  ...,  4.9062e+00,\n",
       "            2.3125e+00,  1.5547e+00],\n",
       "          ...,\n",
       "          [ 2.0996e-01, -3.6523e-01, -9.1406e-01,  ...,  5.2812e+00,\n",
       "            5.8203e-01,  5.1953e-01],\n",
       "          [-5.3125e-01, -2.2266e-01, -7.0312e-01,  ...,  5.1562e+00,\n",
       "            7.6562e-01, -3.5469e+00],\n",
       "          [ 2.2266e-01, -1.5625e-01, -1.5723e-01,  ...,  4.1875e+00,\n",
       "            3.1055e-01, -1.6562e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-5.7861e-02,  3.9978e-03,  3.6774e-03,  ...,  7.1106e-03,\n",
       "           -1.4877e-03,  3.5400e-02],\n",
       "          [ 3.7891e-01, -2.3438e-01, -1.9238e-01,  ...,  3.0078e-01,\n",
       "           -7.9590e-02,  1.1426e-01],\n",
       "          [ 7.6660e-02,  1.9922e-01,  6.1328e-01,  ..., -3.8281e-01,\n",
       "            9.7168e-02,  6.3281e-01],\n",
       "          ...,\n",
       "          [-2.6367e-01, -8.0078e-02,  2.1851e-02,  ..., -1.6113e-01,\n",
       "           -1.6113e-01, -2.2754e-01],\n",
       "          [-5.2344e-01, -1.8164e-01, -4.7070e-01,  ...,  1.0840e-01,\n",
       "           -3.5156e-02, -1.7456e-02],\n",
       "          [-9.4531e-01, -2.9883e-01, -1.2266e+00,  ..., -7.0703e-01,\n",
       "            2.4219e-01,  1.4062e-01]],\n",
       "\n",
       "         [[-5.0354e-03, -1.3489e-02, -7.5073e-03,  ..., -2.2827e-02,\n",
       "            1.2207e-04, -6.0120e-03],\n",
       "          [ 2.0605e-01, -1.9727e-01, -2.9492e-01,  ...,  1.3672e-01,\n",
       "           -7.6953e-01,  3.7956e-04],\n",
       "          [ 3.7500e-01,  7.9590e-02, -5.5908e-02,  ..., -9.0820e-02,\n",
       "           -5.5859e-01, -1.0352e-01],\n",
       "          ...,\n",
       "          [ 5.4688e-01, -1.3203e+00, -4.8438e-01,  ...,  1.4062e-01,\n",
       "            3.1250e-01, -3.0365e-03],\n",
       "          [ 3.0664e-01, -3.3594e-01, -5.8984e-01,  ...,  1.4062e-01,\n",
       "            5.8594e-01,  1.2354e-01],\n",
       "          [-1.5137e-01, -2.3926e-01, -7.9688e-01,  ..., -4.4189e-02,\n",
       "            3.8281e-01, -2.2363e-01]],\n",
       "\n",
       "         [[ 6.2561e-03,  7.7515e-03,  4.7302e-03,  ...,  1.3428e-02,\n",
       "           -6.7139e-03, -2.3193e-03],\n",
       "          [ 3.3203e-02, -4.7070e-01,  5.3711e-02,  ..., -2.1289e-01,\n",
       "           -3.5645e-02,  1.6211e-01],\n",
       "          [-1.6113e-01, -2.6562e-01,  3.0469e-01,  ..., -1.1108e-02,\n",
       "            9.2773e-02,  2.5195e-01],\n",
       "          ...,\n",
       "          [ 5.1953e-01,  8.6328e-01, -8.9355e-02,  ..., -7.4609e-01,\n",
       "            4.9805e-01,  5.2979e-02],\n",
       "          [ 7.3438e-01,  6.5625e-01, -1.4258e-01,  ..., -8.0469e-01,\n",
       "            3.9453e-01,  2.8906e-01],\n",
       "          [ 5.4297e-01,  2.2656e-01, -3.1055e-01,  ..., -1.0469e+00,\n",
       "            3.3203e-01,  7.4219e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.5635e-03,  1.6098e-03, -5.2261e-04,  ...,  4.8523e-03,\n",
       "            9.6436e-03, -3.7956e-04],\n",
       "          [ 1.6699e-01,  2.3926e-01, -9.7656e-02,  ...,  1.1084e-01,\n",
       "            4.0039e-01,  4.3750e-01],\n",
       "          [ 1.5137e-01,  9.3262e-02,  7.1289e-02,  ...,  1.5747e-02,\n",
       "            7.1777e-02,  4.3701e-02],\n",
       "          ...,\n",
       "          [-3.2617e-01, -1.1426e-01, -9.1797e-02,  ...,  4.9023e-01,\n",
       "            8.0859e-01, -3.7500e-01],\n",
       "          [-3.6719e-01, -1.5625e-01, -2.6953e-01,  ...,  1.1230e-01,\n",
       "            8.6719e-01, -7.4707e-02],\n",
       "          [-4.0430e-01,  1.8262e-01, -2.0215e-01,  ..., -1.0400e-01,\n",
       "           -1.3672e-01,  2.6367e-02]],\n",
       "\n",
       "         [[-4.7607e-03, -2.7771e-03,  7.5684e-03,  ...,  9.6512e-04,\n",
       "            3.9978e-03, -8.8501e-03],\n",
       "          [ 2.9688e-01, -2.2070e-01,  2.1777e-01,  ...,  8.1250e-01,\n",
       "            1.7871e-01,  2.0312e-01],\n",
       "          [-4.2578e-01, -3.8281e-01, -5.5859e-01,  ..., -4.0234e-01,\n",
       "            9.6680e-02, -3.9453e-01],\n",
       "          ...,\n",
       "          [ 1.8262e-01,  3.6719e-01, -7.1289e-02,  ...,  1.6992e-01,\n",
       "           -5.7031e-01,  1.9043e-01],\n",
       "          [ 4.9744e-03, -2.8809e-02, -7.6562e-01,  ..., -1.9629e-01,\n",
       "           -3.4961e-01,  1.9531e-01],\n",
       "          [ 3.9844e-01, -3.0859e-01, -7.4707e-02,  ..., -5.1758e-02,\n",
       "           -3.4570e-01, -4.3945e-01]],\n",
       "\n",
       "         [[-4.6387e-03, -1.4114e-03, -7.2327e-03,  ..., -8.8501e-03,\n",
       "            1.7929e-03, -3.5667e-04],\n",
       "          [-6.4941e-02,  5.8105e-02,  2.5391e-01,  ..., -1.4648e-01,\n",
       "           -1.5527e-01,  9.0332e-03],\n",
       "          [-1.7383e-01, -2.5977e-01,  4.9414e-01,  ...,  1.4355e-01,\n",
       "           -1.3184e-01,  9.9609e-02],\n",
       "          ...,\n",
       "          [-8.5156e-01, -1.2354e-01,  7.4707e-02,  ...,  4.8242e-01,\n",
       "            4.0820e-01,  4.3164e-01],\n",
       "          [-3.8477e-01, -9.5703e-02,  3.0664e-01,  ...,  3.9648e-01,\n",
       "            3.1836e-01,  5.3906e-01],\n",
       "          [-3.6133e-01,  1.3281e-01,  2.0312e-01,  ...,  1.0840e-01,\n",
       "            1.4893e-02,  1.6895e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 1.3977e-02,  5.7983e-03, -6.4392e-03,  ..., -1.5234e-01,\n",
       "            3.3203e-01,  5.4297e-01],\n",
       "          [-5.3906e-01,  9.3750e-02, -1.6699e-01,  ..., -2.8320e-01,\n",
       "           -1.3828e+00, -1.9629e-01],\n",
       "          [-1.5820e-01,  4.1016e-01,  1.5234e-01,  ..., -4.4336e-01,\n",
       "           -1.2793e-01,  4.0234e-01],\n",
       "          ...,\n",
       "          [ 8.3984e-02, -4.1406e-01,  1.5625e-02,  ..., -2.9492e-01,\n",
       "           -5.5859e-01, -1.5000e+00],\n",
       "          [-8.1543e-02, -1.9336e-01,  8.3496e-02,  ...,  5.2344e-01,\n",
       "           -7.5781e-01, -1.3438e+00],\n",
       "          [-3.5742e-01, -2.2754e-01,  1.9922e-01,  ..., -2.9144e-03,\n",
       "           -8.0078e-01, -2.0000e+00]],\n",
       "\n",
       "         [[ 1.3123e-02, -1.6479e-02,  2.7588e-02,  ..., -1.1523e-01,\n",
       "           -1.3867e-01,  4.7461e-01],\n",
       "          [ 1.0303e-01, -1.4746e-01, -6.9922e-01,  ...,  2.8125e-01,\n",
       "           -7.8516e-01, -4.6289e-01],\n",
       "          [-2.9297e-01,  1.2500e-01,  7.8613e-02,  ..., -1.8555e-01,\n",
       "            5.9375e-01, -8.2031e-01],\n",
       "          ...,\n",
       "          [ 2.5195e-01, -7.4219e-01,  7.3438e-01,  ..., -7.8906e-01,\n",
       "           -1.2109e-01,  2.9688e+00],\n",
       "          [-4.3359e-01, -7.0312e-01,  2.8516e-01,  ..., -2.3730e-01,\n",
       "            1.9043e-01,  1.5469e+00],\n",
       "          [-7.5391e-01,  8.3984e-02, -2.6562e-01,  ...,  4.8242e-01,\n",
       "            5.4297e-01,  8.0859e-01]],\n",
       "\n",
       "         [[ 1.7578e-02,  5.4932e-03, -3.3691e-02,  ..., -1.3672e-01,\n",
       "            3.1641e-01,  2.6953e-01],\n",
       "          [ 6.8750e-01,  3.0078e-01, -1.1914e-01,  ...,  7.2656e-01,\n",
       "           -7.7734e-01, -5.8203e-01],\n",
       "          [-3.4766e-01,  5.8203e-01, -1.2354e-01,  ..., -5.3516e-01,\n",
       "           -1.0312e+00, -9.0332e-02],\n",
       "          ...,\n",
       "          [ 4.4531e-01,  5.7031e-01,  2.1680e-01,  ..., -1.9727e-01,\n",
       "           -1.3281e+00, -1.3477e-01],\n",
       "          [-7.2266e-02,  9.8633e-02,  6.9531e-01,  ..., -1.4844e-01,\n",
       "           -2.4170e-02,  1.3672e+00],\n",
       "          [-3.0859e-01,  2.8711e-01,  8.5938e-01,  ..., -3.3398e-01,\n",
       "           -7.4609e-01,  1.9375e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 3.7994e-03,  2.1851e-02,  1.4465e-02,  ...,  1.3184e-01,\n",
       "           -2.0996e-01, -1.4941e-01],\n",
       "          [-1.6406e-01, -8.4961e-02, -3.7891e-01,  ..., -1.5991e-02,\n",
       "           -2.3438e-02,  7.0703e-01],\n",
       "          [ 1.6602e-01,  3.0469e-01, -1.5723e-01,  ...,  1.5547e+00,\n",
       "            1.0859e+00, -1.6022e-04],\n",
       "          ...,\n",
       "          [-5.3125e-01, -3.6523e-01, -3.3594e-01,  ...,  1.1953e+00,\n",
       "           -5.8203e-01, -6.7969e-01],\n",
       "          [-9.1406e-01, -3.8672e-01,  3.3984e-01,  ...,  1.3828e+00,\n",
       "            1.4062e-01, -4.1406e-01],\n",
       "          [ 1.2109e-01, -6.1328e-01, -2.7734e-01,  ...,  1.0986e-02,\n",
       "            1.1621e-01,  1.4453e+00]],\n",
       "\n",
       "         [[-6.1646e-03,  2.4872e-03, -2.7618e-03,  ...,  1.8066e-02,\n",
       "           -1.2256e-01, -3.5352e-01],\n",
       "          [-1.3672e-01, -1.3574e-01, -2.4609e-01,  ...,  6.2891e-01,\n",
       "           -6.5625e-01,  6.2891e-01],\n",
       "          [ 2.6758e-01,  4.2578e-01, -9.4727e-02,  ...,  9.2969e-01,\n",
       "            1.1406e+00,  3.6133e-01],\n",
       "          ...,\n",
       "          [ 4.9805e-02,  2.3145e-01,  4.0820e-01,  ...,  1.0625e+00,\n",
       "            1.3359e+00,  7.1484e-01],\n",
       "          [ 1.5332e-01,  2.5977e-01,  3.3203e-01,  ...,  7.4609e-01,\n",
       "            1.2500e+00,  2.6094e+00],\n",
       "          [-6.4453e-02,  6.5625e-01,  4.7266e-01,  ...,  7.8516e-01,\n",
       "           -5.1758e-02,  3.3594e+00]],\n",
       "\n",
       "         [[-2.1118e-02,  2.1118e-02,  1.5625e-02,  ...,  1.6562e+00,\n",
       "           -1.3574e-01, -2.0264e-02],\n",
       "          [ 2.5000e-01, -1.1035e-01,  5.0391e-01,  ..., -3.7969e+00,\n",
       "           -2.4512e-01, -1.2266e+00],\n",
       "          [ 1.3965e-01,  6.2988e-02,  2.0898e-01,  ..., -5.4375e+00,\n",
       "            1.6094e+00, -2.7812e+00],\n",
       "          ...,\n",
       "          [ 4.1992e-01, -1.2061e-01, -2.2266e-01,  ..., -4.6562e+00,\n",
       "           -5.0781e-01, -1.2812e+00],\n",
       "          [ 4.1406e-01, -4.0430e-01,  2.8906e-01,  ..., -3.6406e+00,\n",
       "           -8.9453e-01, -2.2188e+00],\n",
       "          [-6.9336e-02, -2.9492e-01, -2.1875e-01,  ..., -3.5312e+00,\n",
       "           -2.2031e+00,  1.1279e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 1.0681e-02,  4.0283e-03,  7.9956e-03,  ...,  1.1047e-02,\n",
       "            3.2501e-03,  5.3101e-03],\n",
       "          [ 4.0283e-02,  2.2705e-02,  9.3750e-02,  ..., -6.5625e-01,\n",
       "            7.4707e-02, -2.5781e-01],\n",
       "          [-2.3828e-01,  1.9629e-01,  2.2070e-01,  ...,  3.8086e-02,\n",
       "            2.7344e-01, -8.2812e-01],\n",
       "          ...,\n",
       "          [-1.3867e-01, -1.3086e-01, -1.2305e-01,  ...,  2.4316e-01,\n",
       "            3.3447e-02, -4.8828e-01],\n",
       "          [-2.0312e-01,  4.4531e-01,  3.2031e-01,  ...,  5.1172e-01,\n",
       "            3.1836e-01, -6.1719e-01],\n",
       "          [-3.3398e-01, -2.1680e-01,  6.2109e-01,  ...,  1.8848e-01,\n",
       "           -4.8828e-01, -2.4121e-01]],\n",
       "\n",
       "         [[-1.9287e-02,  2.5635e-02,  1.2268e-02,  ..., -1.2891e-01,\n",
       "            1.7166e-03,  3.0518e-03],\n",
       "          [ 2.7539e-01, -5.8838e-02,  1.5820e-01,  ..., -4.6875e-01,\n",
       "           -1.7480e-01,  1.2598e-01],\n",
       "          [-5.7861e-02, -2.1094e-01,  3.6865e-02,  ..., -5.6641e-01,\n",
       "           -1.9336e-01,  2.6245e-02],\n",
       "          ...,\n",
       "          [-9.2773e-02,  3.7109e-01, -3.0396e-02,  ...,  1.3281e+00,\n",
       "            1.5039e-01,  2.2852e-01],\n",
       "          [ 4.4336e-01,  2.4316e-01, -1.3867e-01,  ...,  8.5547e-01,\n",
       "            3.8867e-01,  4.3750e-01],\n",
       "          [ 1.3245e-02, -2.1729e-02,  2.8906e-01,  ...,  1.0781e+00,\n",
       "            2.8564e-02, -1.0498e-01]],\n",
       "\n",
       "         [[-1.0010e-02, -5.5237e-03, -1.1063e-03,  ...,  1.5076e-02,\n",
       "            1.3977e-02, -1.8188e-02],\n",
       "          [-1.3965e-01,  1.4746e-01, -1.6211e-01,  ..., -6.8750e-01,\n",
       "           -3.0664e-01,  3.1445e-01],\n",
       "          [ 4.6094e-01,  4.0039e-01, -1.2500e-01,  ..., -3.6719e-01,\n",
       "           -3.6133e-01, -8.8867e-02],\n",
       "          ...,\n",
       "          [-1.7773e-01,  9.0820e-02,  3.4570e-01,  ..., -2.5195e-01,\n",
       "            1.3965e-01,  4.3555e-01],\n",
       "          [-2.1777e-01, -3.4570e-01, -1.0547e-01,  ..., -1.0742e-01,\n",
       "           -7.8125e-03,  3.3722e-03],\n",
       "          [ 2.5977e-01, -3.9258e-01, -2.0264e-02,  ...,  2.6953e-01,\n",
       "            3.6523e-01, -1.2146e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 3.5858e-03, -2.7832e-02,  2.1729e-02,  ...,  2.7344e-02,\n",
       "           -5.6152e-03, -7.0496e-03],\n",
       "          [-9.7046e-03,  1.2695e-01,  2.0801e-01,  ...,  2.7734e-01,\n",
       "            3.1055e-01,  4.2188e-01],\n",
       "          [-1.3574e-01,  4.8096e-02,  8.7891e-01,  ..., -1.0010e-01,\n",
       "            1.7773e-01, -2.0801e-01],\n",
       "          ...,\n",
       "          [ 1.5332e-01, -2.9688e-01,  1.0449e-01,  ..., -2.0020e-01,\n",
       "           -3.1445e-01, -7.9956e-03],\n",
       "          [ 9.7168e-02, -1.2793e-01, -1.0791e-01,  ..., -3.7305e-01,\n",
       "           -1.8457e-01, -6.5918e-02],\n",
       "          [-3.9551e-02,  8.0566e-03, -1.6406e-01,  ...,  1.6113e-01,\n",
       "           -5.8203e-01,  3.0273e-01]],\n",
       "\n",
       "         [[-1.5488e-03, -1.1826e-03,  1.0498e-02,  ...,  2.3682e-02,\n",
       "           -1.5137e-02, -1.0498e-02],\n",
       "          [-9.2773e-02,  7.5391e-01, -1.3962e-03,  ...,  1.6357e-02,\n",
       "           -5.4297e-01, -8.5449e-02],\n",
       "          [ 3.5547e-01,  1.1172e+00, -2.9883e-01,  ...,  9.7168e-02,\n",
       "            1.9434e-01, -3.1445e-01],\n",
       "          ...,\n",
       "          [-5.9814e-02, -2.0996e-01, -8.8867e-02,  ...,  2.7539e-01,\n",
       "           -3.5645e-02,  2.7930e-01],\n",
       "          [ 8.2397e-03, -1.2158e-01, -1.2598e-01,  ..., -3.1006e-02,\n",
       "            3.1494e-02,  1.3477e-01],\n",
       "          [ 3.3398e-01,  1.7676e-01,  2.7148e-01,  ..., -7.3828e-01,\n",
       "           -2.5586e-01, -1.8652e-01]],\n",
       "\n",
       "         [[ 5.0049e-03,  7.4158e-03,  5.5695e-04,  ...,  5.5847e-03,\n",
       "           -9.3994e-03,  8.7891e-03],\n",
       "          [ 2.8711e-01, -2.0996e-01, -1.9141e-01,  ..., -8.2520e-02,\n",
       "            4.0283e-02, -6.1523e-02],\n",
       "          [-2.8711e-01,  2.0605e-01, -5.2734e-02,  ..., -2.0752e-02,\n",
       "            1.4355e-01,  1.5332e-01],\n",
       "          ...,\n",
       "          [-3.5156e-01,  1.2695e-01, -7.2266e-01,  ..., -1.2451e-01,\n",
       "           -9.6191e-02,  1.0449e-01],\n",
       "          [ 1.3977e-02,  1.9043e-01, -2.2559e-01,  ...,  2.4707e-01,\n",
       "           -2.0117e-01, -6.8359e-02],\n",
       "          [ 3.7305e-01,  4.0771e-02, -3.6133e-01,  ..., -2.7148e-01,\n",
       "           -5.7812e-01,  4.2578e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 1.9165e-02, -4.3030e-03,  4.9744e-03,  ..., -1.0400e-01,\n",
       "            4.5117e-01,  1.6719e+00],\n",
       "          [-5.1172e-01,  2.6562e-01, -1.0400e-01,  ...,  2.5000e+00,\n",
       "            7.2021e-03, -6.5000e+00],\n",
       "          [-6.8359e-01,  4.1211e-01, -5.6641e-02,  ...,  2.2969e+00,\n",
       "           -1.5312e+00, -7.4062e+00],\n",
       "          ...,\n",
       "          [ 6.4453e-01, -7.9688e-01, -5.1514e-02,  ...,  4.0625e+00,\n",
       "            1.9922e+00, -5.6562e+00],\n",
       "          [-5.7617e-02, -2.9883e-01, -1.2695e-01,  ...,  1.4922e+00,\n",
       "            5.8984e-01, -7.3125e+00],\n",
       "          [-5.8984e-01, -2.7148e-01, -5.0781e-01,  ...,  5.2734e-01,\n",
       "            1.1406e+00, -5.9375e+00]],\n",
       "\n",
       "         [[-1.8311e-02,  1.3062e-02, -9.5215e-03,  ..., -3.4766e-01,\n",
       "            2.1484e-01,  2.1387e-01],\n",
       "          [ 3.0859e-01,  2.4805e-01, -1.9727e-01,  ..., -2.5391e-01,\n",
       "            6.0938e-01,  8.9453e-01],\n",
       "          [ 2.0801e-01,  1.1621e-01,  1.0156e+00,  ..., -3.1836e-01,\n",
       "           -2.2344e+00,  6.6797e-01],\n",
       "          ...,\n",
       "          [ 7.5391e-01,  3.6328e-01,  1.5000e+00,  ...,  2.8711e-01,\n",
       "           -4.5898e-01, -1.3516e+00],\n",
       "          [ 4.5312e-01,  1.5234e-01,  6.2109e-01,  ...,  7.8516e-01,\n",
       "            1.9824e-01, -4.3750e-01],\n",
       "          [ 3.0664e-01, -3.5742e-01,  4.2578e-01,  ...,  5.3516e-01,\n",
       "           -1.4531e+00, -8.3594e-01]],\n",
       "\n",
       "         [[ 1.5320e-02,  2.0630e-02, -2.2339e-02,  ..., -7.9688e-01,\n",
       "            3.7695e-01,  3.9844e-01],\n",
       "          [-9.2188e-01, -3.9844e-01, -1.0791e-01,  ...,  2.3281e+00,\n",
       "           -2.2559e-01, -2.2031e+00],\n",
       "          [-5.3906e-01, -5.8203e-01,  5.5078e-01,  ...,  1.0469e+00,\n",
       "           -1.3984e+00, -3.6406e+00],\n",
       "          ...,\n",
       "          [ 2.6953e-01,  1.2500e+00, -1.2578e+00,  ...,  2.3145e-01,\n",
       "           -1.1875e+00, -2.9844e+00],\n",
       "          [ 5.1953e-01, -1.9824e-01, -8.3984e-01,  ...,  1.0469e+00,\n",
       "           -7.2266e-01, -1.0156e+00],\n",
       "          [-1.0596e-01,  7.6953e-01,  7.3047e-01,  ...,  2.2339e-02,\n",
       "           -3.2344e+00, -1.4609e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 8.1177e-03, -1.3000e-02, -8.4229e-03,  ..., -2.2583e-02,\n",
       "            2.2949e-01, -1.1641e+00],\n",
       "          [ 1.6797e-01, -1.5820e-01,  1.0352e-01,  ...,  7.5391e-01,\n",
       "            1.5547e+00,  2.3906e+00],\n",
       "          [-3.0469e-01,  6.1328e-01,  5.1172e-01,  ...,  1.0437e-02,\n",
       "            1.4922e+00,  2.5312e+00],\n",
       "          ...,\n",
       "          [-1.4648e-03, -5.9766e-01, -1.3574e-01,  ...,  7.8516e-01,\n",
       "            9.6680e-02,  4.2188e+00],\n",
       "          [ 2.3828e-01,  6.8054e-03, -5.5469e-01,  ...,  1.9062e+00,\n",
       "           -1.7500e+00,  4.3125e+00],\n",
       "          [ 1.9531e-01, -6.3672e-01,  4.5312e-01,  ...,  2.2031e+00,\n",
       "           -1.7188e+00,  4.4062e+00]],\n",
       "\n",
       "         [[-4.1260e-02, -2.7222e-02, -7.7248e-05,  ...,  4.3555e-01,\n",
       "            4.0234e-01, -1.7578e-01],\n",
       "          [ 3.3984e-01,  3.9258e-01,  1.0742e-01,  ..., -1.0000e+00,\n",
       "            1.2500e+00,  6.7578e-01],\n",
       "          [-7.5195e-02,  1.2578e+00,  5.7422e-01,  ..., -3.1406e+00,\n",
       "            1.7031e+00, -3.4766e-01],\n",
       "          ...,\n",
       "          [-3.2617e-01,  7.7734e-01,  5.8594e-02,  ...,  3.4531e+00,\n",
       "            8.6914e-02,  1.7656e+00],\n",
       "          [-6.6406e-01,  3.0078e-01, -9.2188e-01,  ...,  2.4844e+00,\n",
       "           -1.7773e-01,  4.8633e-01],\n",
       "          [ 5.8984e-01,  3.7305e-01, -5.1562e-01,  ...,  4.7188e+00,\n",
       "           -1.7969e+00,  9.1797e-02]],\n",
       "\n",
       "         [[-6.5002e-03,  1.3489e-02,  6.8283e-04,  ..., -7.2266e-01,\n",
       "           -4.4922e-01,  1.1182e-01],\n",
       "          [ 5.3125e-01, -2.7344e-01,  3.1055e-01,  ..., -7.7734e-01,\n",
       "            1.6641e+00, -3.8867e-01],\n",
       "          [ 3.1836e-01,  1.6602e-01,  1.8945e-01,  ...,  2.4805e-01,\n",
       "            1.6094e+00, -1.7422e+00],\n",
       "          ...,\n",
       "          [ 2.0020e-02,  1.8262e-01, -1.0059e-01,  ..., -6.8359e-02,\n",
       "            1.6406e+00, -7.7734e-01],\n",
       "          [ 1.9238e-01,  3.9062e-03, -1.2305e-01,  ...,  3.8477e-01,\n",
       "           -8.2812e-01, -1.4922e+00],\n",
       "          [ 1.8359e-01, -4.5703e-01, -4.6289e-01,  ..., -3.1055e-01,\n",
       "            8.6328e-01, -5.8984e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 3.4637e-03, -1.0803e-02, -3.2196e-03,  ...,  4.7913e-03,\n",
       "           -4.5166e-03, -3.9978e-03],\n",
       "          [-1.6895e-01,  7.5781e-01,  6.5234e-01,  ...,  1.6699e-01,\n",
       "            7.1875e-01,  9.0625e-01],\n",
       "          [-3.3789e-01,  2.5977e-01,  5.6250e-01,  ..., -3.0640e-02,\n",
       "            4.5117e-01,  6.6797e-01],\n",
       "          ...,\n",
       "          [ 6.7969e-01, -6.5918e-02,  6.2891e-01,  ..., -2.1191e-01,\n",
       "            1.0078e+00,  2.4658e-02],\n",
       "          [ 2.8906e-01, -1.9043e-01,  1.4844e-01,  ...,  9.3750e-02,\n",
       "            9.4531e-01,  1.8750e-01],\n",
       "          [ 5.8594e-01, -6.6016e-01, -2.7344e-01,  ...,  1.7676e-01,\n",
       "            6.0547e-01, -1.0400e-01]],\n",
       "\n",
       "         [[ 1.1902e-02,  5.5908e-02,  1.4771e-02,  ...,  4.0283e-02,\n",
       "           -4.5654e-02, -2.1484e-02],\n",
       "          [-5.3906e-01, -1.0234e+00, -3.9453e-01,  ..., -2.7539e-01,\n",
       "           -2.4316e-01, -1.0254e-01],\n",
       "          [-1.6699e-01, -7.5391e-01, -5.5859e-01,  ..., -6.6016e-01,\n",
       "           -3.8477e-01,  3.5156e-02],\n",
       "          ...,\n",
       "          [-3.7109e-01, -3.3008e-01, -2.9883e-01,  ..., -1.0547e-01,\n",
       "           -2.8711e-01,  4.4141e-01],\n",
       "          [-2.5781e-01, -2.5977e-01,  7.7637e-02,  ...,  5.1172e-01,\n",
       "           -2.9883e-01,  8.3984e-02],\n",
       "          [-3.9062e-01,  1.0352e-01,  3.7305e-01,  ...,  2.3145e-01,\n",
       "            2.1680e-01,  1.9922e-01]],\n",
       "\n",
       "         [[-4.4861e-03,  2.0630e-02, -2.0386e-02,  ...,  4.5166e-03,\n",
       "            1.0071e-02,  5.8289e-03],\n",
       "          [-4.8828e-02, -1.0596e-01,  9.1797e-02,  ..., -1.7188e-01,\n",
       "           -2.1094e-01,  1.6113e-01],\n",
       "          [-2.5977e-01, -9.0332e-03, -7.3730e-02,  ..., -5.1514e-02,\n",
       "           -4.7852e-01,  1.4551e-01],\n",
       "          ...,\n",
       "          [-1.4551e-01,  3.4375e-01,  3.0273e-01,  ..., -3.0469e-01,\n",
       "           -8.2031e-01, -4.8242e-01],\n",
       "          [-1.7188e-01,  2.9297e-02,  2.4219e-01,  ..., -1.7188e-01,\n",
       "           -4.3750e-01, -8.7402e-02],\n",
       "          [ 1.3574e-01,  3.7305e-01, -4.5654e-02,  ..., -1.3245e-02,\n",
       "           -8.8672e-01,  3.2031e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 8.9722e-03,  8.1177e-03,  1.3123e-02,  ...,  1.3245e-02,\n",
       "           -6.2256e-03,  1.1169e-02],\n",
       "          [-8.6426e-02, -2.7344e-01, -8.3008e-02,  ...,  2.5586e-01,\n",
       "           -2.4023e-01, -2.4707e-01],\n",
       "          [ 9.1797e-02, -6.4453e-01, -1.7090e-01,  ...,  2.4121e-01,\n",
       "           -3.1445e-01, -4.2773e-01],\n",
       "          ...,\n",
       "          [-5.3711e-02, -1.1016e+00, -3.5938e-01,  ..., -6.5430e-02,\n",
       "            4.1504e-02,  2.8320e-01],\n",
       "          [ 4.2236e-02, -5.5469e-01, -3.3984e-01,  ..., -1.7090e-02,\n",
       "           -3.1836e-01,  5.7812e-01],\n",
       "          [ 2.2852e-01, -3.5938e-01,  2.8320e-01,  ..., -5.9814e-02,\n",
       "           -1.1084e-01,  4.4727e-01]],\n",
       "\n",
       "         [[-7.8125e-03,  3.3112e-03,  1.0376e-02,  ..., -8.4839e-03,\n",
       "           -1.5625e-02, -1.2112e-04],\n",
       "          [ 1.7548e-03, -4.0625e-01,  1.9727e-01,  ...,  2.0996e-02,\n",
       "            2.4292e-02,  2.2266e-01],\n",
       "          [-1.8164e-01, -2.3047e-01, -3.3594e-01,  ..., -2.2949e-01,\n",
       "            2.5391e-01, -7.8125e-02],\n",
       "          ...,\n",
       "          [-3.8281e-01, -2.8125e-01, -8.2520e-02,  ..., -9.2773e-02,\n",
       "           -1.0938e-01,  1.9238e-01],\n",
       "          [-7.7637e-02, -1.1230e-01, -3.1055e-01,  ...,  6.3965e-02,\n",
       "           -2.3926e-01,  1.6895e-01],\n",
       "          [-8.9844e-02,  1.9043e-01, -2.3145e-01,  ...,  6.0059e-02,\n",
       "           -2.4414e-01, -1.3379e-01]],\n",
       "\n",
       "         [[ 5.6152e-03, -6.9885e-03, -2.1362e-02,  ..., -1.8539e-03,\n",
       "           -3.0060e-03,  4.6997e-03],\n",
       "          [ 4.1260e-02,  1.8652e-01,  2.6953e-01,  ...,  4.8438e-01,\n",
       "            7.1094e-01, -1.1865e-01],\n",
       "          [ 1.7383e-01,  7.1875e-01, -2.8931e-02,  ...,  2.5195e-01,\n",
       "            5.1172e-01,  1.9922e-01],\n",
       "          ...,\n",
       "          [ 2.0898e-01,  3.5938e-01,  4.2725e-02,  ..., -1.5039e-01,\n",
       "            1.2109e-01, -2.2754e-01],\n",
       "          [ 2.9688e-01,  4.8438e-01,  1.5137e-02,  ..., -2.9175e-02,\n",
       "            3.4570e-01,  2.8198e-02],\n",
       "          [ 4.3945e-01,  6.2891e-01,  8.8379e-02,  ...,  4.9023e-01,\n",
       "            6.6406e-01, -1.4551e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-1.4221e-02, -7.5684e-03,  1.8188e-02,  ...,  1.3203e+00,\n",
       "            4.0039e-02,  3.9844e-01],\n",
       "          [-6.3477e-03, -2.8516e-01, -2.4902e-01,  ..., -3.2422e-01,\n",
       "           -2.0938e+00, -6.8750e-01],\n",
       "          [-1.2207e-01,  4.3701e-02,  2.5146e-02,  ..., -1.8594e+00,\n",
       "           -3.2656e+00,  1.1182e-01],\n",
       "          ...,\n",
       "          [ 7.5000e-01, -3.2031e-01,  6.3281e-01,  ..., -1.7578e+00,\n",
       "           -1.6641e+00, -2.3281e+00],\n",
       "          [ 1.8203e+00,  8.7891e-03,  3.7305e-01,  ..., -3.2656e+00,\n",
       "           -1.0078e+00, -2.0156e+00],\n",
       "          [ 7.0312e-02,  1.1816e-01, -2.8711e-01,  ..., -1.1875e+00,\n",
       "           -8.3594e-01, -6.2500e-01]],\n",
       "\n",
       "         [[-1.1841e-02, -1.3672e-02,  3.0670e-03,  ..., -2.4902e-01,\n",
       "           -1.6113e-01, -1.3770e-01],\n",
       "          [-1.6797e-01, -8.5938e-02,  7.3047e-01,  ..., -2.7148e-01,\n",
       "            2.0898e-01, -2.0312e+00],\n",
       "          [-1.8594e+00,  3.3008e-01, -8.2031e-02,  ...,  5.8203e-01,\n",
       "            2.2031e+00, -1.2812e+00],\n",
       "          ...,\n",
       "          [ 1.3984e+00, -8.0469e-01, -5.8594e-03,  ...,  4.1602e-01,\n",
       "            3.9062e-01,  4.2578e-01],\n",
       "          [ 1.0938e+00, -3.4180e-01,  5.0781e-02,  ...,  1.3281e+00,\n",
       "            1.6719e+00,  1.7500e+00],\n",
       "          [-2.3828e-01,  2.1484e-01,  1.2500e-01,  ..., -1.1484e+00,\n",
       "            1.0859e+00,  3.5889e-02]],\n",
       "\n",
       "         [[-1.7578e-02, -1.1047e-02, -2.7618e-03,  ...,  2.9492e-01,\n",
       "            7.2266e-02, -6.2109e-01],\n",
       "          [-3.4766e-01, -3.5938e-01, -9.9121e-02,  ..., -2.7148e-01,\n",
       "           -1.0312e+00, -1.4766e+00],\n",
       "          [-3.3203e-02,  2.3828e-01,  1.3184e-02,  ..., -5.0781e-01,\n",
       "           -1.7500e+00, -1.5391e+00],\n",
       "          ...,\n",
       "          [ 1.2500e-01,  2.8711e-01,  3.4424e-02,  ...,  3.3789e-01,\n",
       "           -1.1797e+00, -3.7891e-01],\n",
       "          [ 4.4922e-02, -8.7500e-01, -9.4727e-02,  ..., -7.4219e-01,\n",
       "           -6.9922e-01,  8.0078e-01],\n",
       "          [-1.4453e-01, -3.3203e-01, -4.1211e-01,  ..., -2.2500e+00,\n",
       "           -2.4219e+00, -4.4141e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-9.4604e-03,  1.2329e-02,  2.3041e-03,  ...,  1.8555e-01,\n",
       "            2.3906e+00, -1.8921e-02],\n",
       "          [ 2.7148e-01, -2.4609e-01,  2.5391e-02,  ..., -1.7109e+00,\n",
       "           -1.6016e+00, -3.4570e-01],\n",
       "          [ 3.1641e-01, -3.5352e-01, -3.1641e-01,  ..., -2.9219e+00,\n",
       "           -4.5625e+00, -6.0156e-01],\n",
       "          ...,\n",
       "          [-1.1426e-01, -4.9219e-01,  4.1406e-01,  ...,  5.6641e-01,\n",
       "           -3.4062e+00, -2.2188e+00],\n",
       "          [-6.9141e-01,  9.5703e-02,  3.2031e-01,  ...,  7.5195e-02,\n",
       "           -4.6250e+00, -2.8125e+00],\n",
       "          [ 3.0859e-01,  2.7539e-01,  1.0312e+00,  ..., -1.5156e+00,\n",
       "           -4.0938e+00, -3.2656e+00]],\n",
       "\n",
       "         [[ 3.1494e-02,  7.6599e-03,  2.8076e-03,  ..., -1.0078e+00,\n",
       "            1.1016e+00, -8.3984e-01],\n",
       "          [-8.2422e-01, -1.8848e-01, -2.3438e-01,  ..., -1.6328e+00,\n",
       "            2.5469e+00, -1.6113e-01],\n",
       "          [-3.5742e-01,  2.5586e-01, -1.7822e-02,  ..., -6.9141e-01,\n",
       "            8.7891e-02, -2.6367e-01],\n",
       "          ...,\n",
       "          [ 2.1680e-01, -6.6016e-01,  4.1992e-01,  ...,  1.6406e+00,\n",
       "            1.1719e+00, -1.0938e+00],\n",
       "          [-3.8086e-01,  3.0273e-01,  1.0234e+00,  ...,  2.9219e+00,\n",
       "           -4.1406e-01, -9.2969e-01],\n",
       "          [ 4.7070e-01, -7.2656e-01, -1.4160e-01,  ...,  2.8438e+00,\n",
       "            6.6016e-01,  9.6094e-01]],\n",
       "\n",
       "         [[ 3.9795e-02, -5.9326e-02, -1.4587e-02,  ..., -1.2988e-01,\n",
       "           -2.1094e-01, -5.3223e-02],\n",
       "          [ 6.1328e-01, -1.1523e-01,  2.4316e-01,  ...,  3.0625e+00,\n",
       "            2.1484e-01,  1.0781e+00],\n",
       "          [-4.6143e-02, -2.6758e-01,  1.4038e-02,  ...,  1.4219e+00,\n",
       "            2.4062e+00, -2.5195e-01],\n",
       "          ...,\n",
       "          [-2.4316e-01,  3.3203e-01,  3.0078e-01,  ...,  1.4141e+00,\n",
       "            2.7812e+00,  9.3750e-01],\n",
       "          [ 1.6992e-01,  3.1641e-01,  5.3516e-01,  ...,  1.8047e+00,\n",
       "            1.5391e+00,  1.7344e+00],\n",
       "          [-3.8086e-01, -1.7285e-01,  2.3340e-01,  ...,  1.8984e+00,\n",
       "            1.9688e+00,  6.4062e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 9.6436e-03,  2.2095e-02, -6.7444e-03,  ..., -3.4142e-04,\n",
       "            3.1982e-02, -9.9487e-03],\n",
       "          [-3.9648e-01, -7.8125e-02,  5.3125e-01,  ...,  1.1658e-02,\n",
       "            2.6367e-01,  7.0703e-01],\n",
       "          [ 6.8054e-03,  3.8086e-02,  7.6953e-01,  ..., -9.9609e-02,\n",
       "            2.8516e-01,  3.3008e-01],\n",
       "          ...,\n",
       "          [ 1.0596e-01,  5.6396e-02, -7.0703e-01,  ...,  2.7734e-01,\n",
       "            9.2188e-01,  8.6426e-02],\n",
       "          [-5.7617e-02, -6.9824e-02, -6.4941e-02,  ..., -1.4941e-01,\n",
       "            8.4473e-02, -3.4424e-02],\n",
       "          [-2.0312e-01, -1.1182e-01,  4.9805e-01,  ..., -1.9434e-01,\n",
       "           -8.1055e-02, -1.5137e-01]],\n",
       "\n",
       "         [[-9.7656e-02, -8.5449e-02,  6.7383e-02,  ...,  1.7944e-02,\n",
       "            2.4292e-02, -7.5195e-02],\n",
       "          [ 1.4746e-01, -3.6719e-01, -3.8281e-01,  ..., -1.6504e-01,\n",
       "           -3.8330e-02, -1.1230e-01],\n",
       "          [ 2.0801e-01, -2.0898e-01, -9.3750e-02,  ...,  1.3281e-01,\n",
       "           -9.5703e-02,  3.5352e-01],\n",
       "          ...,\n",
       "          [-1.3672e-01,  1.7383e-01,  8.6426e-02,  ...,  3.5742e-01,\n",
       "           -2.3633e-01,  9.2773e-02],\n",
       "          [ 1.0010e-01, -8.0490e-04, -4.9219e-01,  ...,  2.9883e-01,\n",
       "           -1.0156e-01,  3.3008e-01],\n",
       "          [ 3.4180e-02,  7.3242e-02,  1.0376e-02,  ...,  5.0391e-01,\n",
       "           -1.0303e-01,  7.6660e-02]],\n",
       "\n",
       "         [[ 6.9824e-02,  5.0537e-02, -4.8828e-02,  ...,  3.5889e-02,\n",
       "            2.4780e-02,  2.4536e-02],\n",
       "          [-3.9648e-01, -5.1172e-01,  4.4531e-01,  ...,  3.1641e-01,\n",
       "            3.5742e-01, -1.7773e-01],\n",
       "          [-5.5664e-02, -8.4375e-01,  1.8848e-01,  ..., -5.4297e-01,\n",
       "            9.7266e-01,  2.6562e-01],\n",
       "          ...,\n",
       "          [-1.6602e-01, -4.4678e-02,  4.1992e-01,  ..., -3.9453e-01,\n",
       "            2.1973e-01, -1.1094e+00],\n",
       "          [-7.4707e-02,  5.3711e-02,  8.7891e-01,  ..., -6.0425e-03,\n",
       "            3.5352e-01, -4.1406e-01],\n",
       "          [ 5.1953e-01, -4.8047e-01, -4.9219e-01,  ...,  4.7266e-01,\n",
       "            1.5527e-01, -7.8125e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 7.4463e-03, -3.3203e-02,  7.0953e-04,  ..., -4.0588e-03,\n",
       "            2.1484e-02, -2.2339e-02],\n",
       "          [ 2.0312e-01, -6.9922e-01,  2.8516e-01,  ...,  9.9609e-02,\n",
       "           -1.8945e-01,  2.0020e-01],\n",
       "          [ 2.3242e-01, -1.7090e-01,  4.9133e-03,  ..., -3.7109e-02,\n",
       "            2.5195e-01,  1.1279e-01],\n",
       "          ...,\n",
       "          [ 3.9062e-01, -5.6250e-01,  3.9453e-01,  ..., -3.6621e-02,\n",
       "           -3.2812e-01, -3.7891e-01],\n",
       "          [ 4.8438e-01, -4.0234e-01,  8.1787e-03,  ..., -9.1797e-02,\n",
       "           -2.2461e-01, -2.5977e-01],\n",
       "          [-1.9043e-01, -7.7344e-01,  2.6562e-01,  ..., -2.8906e-01,\n",
       "           -3.2422e-01,  2.0898e-01]],\n",
       "\n",
       "         [[-2.9297e-02,  2.2095e-02, -2.8687e-03,  ...,  1.9226e-03,\n",
       "           -8.8501e-03,  1.6724e-02],\n",
       "          [ 5.4321e-03,  5.1953e-01,  2.7344e-01,  ..., -1.2878e-02,\n",
       "           -2.0898e-01, -3.1836e-01],\n",
       "          [ 2.3828e-01,  3.4570e-01, -6.7749e-03,  ...,  5.6885e-02,\n",
       "            7.8613e-02, -8.8867e-02],\n",
       "          ...,\n",
       "          [ 1.7383e-01,  4.8047e-01,  8.1250e-01,  ...,  1.4160e-01,\n",
       "           -4.5508e-01,  1.8652e-01],\n",
       "          [-1.0547e-01,  2.9688e-01,  4.5898e-01,  ..., -3.1836e-01,\n",
       "           -2.6172e-01,  3.3594e-01],\n",
       "          [ 4.1406e-01,  1.0791e-01, -3.4766e-01,  ...,  2.0898e-01,\n",
       "           -1.1279e-01,  4.6289e-01]],\n",
       "\n",
       "         [[ 1.5182e-03,  7.0190e-03,  1.1108e-02,  ...,  1.7822e-02,\n",
       "           -1.1963e-02,  4.5586e-04],\n",
       "          [-2.6758e-01, -2.8125e-01,  7.1875e-01,  ...,  2.5586e-01,\n",
       "            1.1572e-01, -1.0303e-01],\n",
       "          [-2.8125e-01, -5.5859e-01,  4.6875e-01,  ..., -2.5977e-01,\n",
       "           -3.3594e-01, -2.9883e-01],\n",
       "          ...,\n",
       "          [-1.2891e-01, -3.2227e-01, -2.3535e-01,  ...,  2.0215e-01,\n",
       "            1.3086e-01,  4.5410e-02],\n",
       "          [-8.5938e-02, -5.5420e-02,  1.2598e-01,  ..., -1.4355e-01,\n",
       "            9.1309e-02,  5.6641e-01],\n",
       "          [-2.7148e-01, -2.9883e-01,  8.9844e-01,  ...,  3.1055e-01,\n",
       "            4.9219e-01,  4.5117e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 8.4229e-03, -2.1484e-02, -8.3923e-04,  ...,  8.2520e-02,\n",
       "           -1.3672e-01,  2.3926e-02],\n",
       "          [ 2.3828e-01, -2.2363e-01,  3.3203e-01,  ...,  5.7031e-01,\n",
       "           -6.3281e-01,  1.4453e-01],\n",
       "          [ 4.4336e-01, -5.0781e-01, -4.0527e-02,  ..., -9.1406e-01,\n",
       "           -1.6562e+00, -3.0664e-01],\n",
       "          ...,\n",
       "          [-2.5391e-01, -3.2031e-01,  1.5625e-01,  ...,  3.1128e-03,\n",
       "            2.3730e-01,  7.0312e-01],\n",
       "          [-4.9805e-01,  3.5547e-01, -3.0078e-01,  ..., -4.5166e-02,\n",
       "           -8.6060e-03, -2.8711e-01],\n",
       "          [ 6.2500e-01,  4.1016e-01,  8.1055e-02,  ..., -2.7734e-01,\n",
       "            8.1641e-01, -6.7383e-02]],\n",
       "\n",
       "         [[ 2.7222e-02,  2.0504e-04,  2.8442e-02,  ...,  2.4219e-01,\n",
       "            1.7969e-01,  2.4023e-01],\n",
       "          [ 4.4922e-01,  1.9141e-01, -1.4844e-01,  ...,  4.1992e-01,\n",
       "           -5.5078e-01, -4.9219e-01],\n",
       "          [ 8.5938e-01, -4.3750e-01, -4.4922e-01,  ..., -3.9062e-01,\n",
       "           -1.6094e+00, -5.6250e-01],\n",
       "          ...,\n",
       "          [-1.7773e-01, -2.1387e-01, -3.2617e-01,  ...,  1.2734e+00,\n",
       "            1.4404e-02, -2.5156e+00],\n",
       "          [-1.8359e-01, -8.4961e-02, -1.8457e-01,  ..., -1.8750e+00,\n",
       "           -1.2188e+00, -1.2188e+00],\n",
       "          [ 5.9766e-01,  9.9609e-01, -5.1562e-01,  ..., -1.3359e+00,\n",
       "           -9.6484e-01,  6.1328e-01]],\n",
       "\n",
       "         [[-2.1729e-02,  5.1880e-03,  1.7456e-02,  ..., -1.0234e+00,\n",
       "           -1.9336e-01, -2.0625e+00],\n",
       "          [-2.6172e-01,  1.7773e-01,  1.3281e-01,  ...,  2.1875e+00,\n",
       "           -1.7188e+00,  2.8438e+00],\n",
       "          [-3.2422e-01,  7.0801e-02,  6.9275e-03,  ...,  2.8750e+00,\n",
       "           -1.5156e+00,  4.1562e+00],\n",
       "          ...,\n",
       "          [ 1.3184e-01,  1.4844e-01,  3.3398e-01,  ...,  4.6875e+00,\n",
       "           -3.7031e+00,  3.9531e+00],\n",
       "          [ 3.3008e-01,  3.6914e-01, -2.6562e-01,  ...,  3.7500e+00,\n",
       "           -8.7891e-01,  5.8750e+00],\n",
       "          [ 3.1445e-01,  1.2109e-01,  7.4219e-02,  ...,  2.4062e+00,\n",
       "           -4.3750e-01,  5.4375e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-2.1362e-02,  1.6968e-02, -1.2451e-02,  ...,  1.9434e-01,\n",
       "            4.4678e-02,  1.2402e-01],\n",
       "          [-4.2188e-01, -4.7852e-01,  2.0752e-03,  ..., -1.5547e+00,\n",
       "           -1.8047e+00, -2.2656e+00],\n",
       "          [ 2.7344e-02, -5.1562e-01, -1.4258e-01,  ..., -1.7422e+00,\n",
       "           -2.5312e+00, -2.5312e+00],\n",
       "          ...,\n",
       "          [ 2.7588e-02, -5.1270e-03,  5.4688e-01,  ..., -2.0781e+00,\n",
       "           -1.6250e+00, -1.3828e+00],\n",
       "          [-1.2695e-01,  4.1211e-01, -3.7695e-01,  ..., -9.8828e-01,\n",
       "           -1.7734e+00, -1.9141e+00],\n",
       "          [-2.4902e-01,  3.5352e-01, -3.1738e-02,  ...,  2.7734e-01,\n",
       "           -1.4297e+00,  1.0625e+00]],\n",
       "\n",
       "         [[-1.5503e-02, -4.4861e-03,  3.5706e-03,  ...,  2.4512e-01,\n",
       "            2.6562e+00,  3.1982e-02],\n",
       "          [ 2.3438e-01,  1.8262e-01,  5.1172e-01,  ...,  3.7305e-01,\n",
       "           -3.1406e+00, -1.5430e-01],\n",
       "          [-4.6875e-01, -2.7734e-01,  5.9375e-01,  ..., -5.7983e-03,\n",
       "           -6.0312e+00,  2.6489e-02],\n",
       "          ...,\n",
       "          [ 4.2578e-01, -9.7656e-04,  3.6133e-01,  ..., -2.2500e+00,\n",
       "           -3.8594e+00, -9.3750e-01],\n",
       "          [ 2.3828e-01, -3.7500e-01,  4.5898e-01,  ..., -2.5781e+00,\n",
       "           -5.0000e+00, -1.5234e+00],\n",
       "          [-5.5078e-01, -4.1797e-01,  1.6406e-01,  ..., -2.2500e+00,\n",
       "           -3.4062e+00, -3.2969e+00]],\n",
       "\n",
       "         [[ 6.4087e-03,  6.9427e-04,  2.1484e-02,  ..., -1.5234e-01,\n",
       "           -1.3574e-01,  1.9824e-01],\n",
       "          [-1.6602e-01, -5.8594e-02, -6.5234e-01,  ..., -1.3359e+00,\n",
       "           -5.5859e-01, -7.2266e-01],\n",
       "          [ 1.8311e-03,  3.4961e-01, -5.1953e-01,  ..., -2.3125e+00,\n",
       "           -2.2031e+00, -4.1504e-02],\n",
       "          ...,\n",
       "          [-2.1289e-01,  5.3906e-01,  6.7578e-01,  ..., -2.2656e+00,\n",
       "           -1.1172e+00, -2.7656e+00],\n",
       "          [-1.4844e-01, -1.8457e-01,  1.6602e-01,  ..., -2.0625e+00,\n",
       "            5.0537e-02, -2.1562e+00],\n",
       "          [-3.5156e-02, -4.6094e-01,  6.4453e-02,  ..., -2.0781e+00,\n",
       "            1.1328e+00, -3.2969e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 1.6113e-02,  3.0762e-02,  1.4221e-02,  ..., -6.0425e-03,\n",
       "           -6.5308e-03, -1.5198e-02],\n",
       "          [-3.5352e-01, -4.2773e-01,  3.9648e-01,  ...,  5.0000e-01,\n",
       "            2.1362e-02,  2.1484e-01],\n",
       "          [ 1.8799e-02, -8.4375e-01,  1.5137e-01,  ...,  5.8350e-02,\n",
       "           -1.3965e-01, -5.3223e-02],\n",
       "          ...,\n",
       "          [-3.6523e-01, -5.7031e-01,  5.2246e-02,  ...,  3.8086e-01,\n",
       "           -9.4531e-01, -2.1240e-02],\n",
       "          [-1.1230e-02, -8.2031e-02,  1.2891e-01,  ...,  5.9766e-01,\n",
       "           -6.8750e-01, -4.4922e-01],\n",
       "          [-1.5918e-01,  7.2754e-02,  3.0469e-01,  ...,  2.5977e-01,\n",
       "            2.1582e-01, -2.0898e-01]],\n",
       "\n",
       "         [[-2.3193e-02, -5.8289e-03, -9.8877e-03,  ..., -1.5869e-02,\n",
       "            5.0735e-04,  2.5146e-02],\n",
       "          [ 2.9688e-01, -5.5176e-02, -2.2559e-01,  ..., -2.1362e-02,\n",
       "            3.9307e-02, -1.1523e-01],\n",
       "          [ 2.4609e-01,  1.1230e-01,  1.5527e-01,  ...,  1.5747e-02,\n",
       "            1.5918e-01,  2.8906e-01],\n",
       "          ...,\n",
       "          [ 3.7305e-01, -3.1250e-01, -2.8711e-01,  ...,  1.0449e-01,\n",
       "           -2.5000e-01, -1.9336e-01],\n",
       "          [ 1.6211e-01,  1.7969e-01,  2.0410e-01,  ...,  4.1260e-02,\n",
       "           -1.3867e-01,  1.3379e-01],\n",
       "          [-1.4551e-01,  6.2500e-01,  1.9141e-01,  ..., -3.9795e-02,\n",
       "            1.7773e-01, -6.7871e-02]],\n",
       "\n",
       "         [[-3.3264e-03, -1.8066e-02,  1.0803e-02,  ...,  7.6599e-03,\n",
       "           -5.2643e-04,  1.2159e-04],\n",
       "          [-4.7461e-01, -3.3008e-01,  4.3555e-01,  ...,  3.6523e-01,\n",
       "           -3.9551e-02, -4.7070e-01],\n",
       "          [-1.0107e-01,  6.5234e-01,  1.4062e-01,  ...,  3.7891e-01,\n",
       "           -4.0039e-02, -1.1719e-01],\n",
       "          ...,\n",
       "          [-6.0547e-01,  7.1289e-02,  2.9883e-01,  ..., -2.6758e-01,\n",
       "            9.2773e-02,  3.3936e-02],\n",
       "          [ 1.3281e+00, -2.1973e-01, -2.2656e-01,  ..., -2.5977e-01,\n",
       "           -5.5859e-01,  6.3281e-01],\n",
       "          [-9.1797e-02, -9.6680e-02, -1.1426e-01,  ...,  4.2773e-01,\n",
       "           -6.4941e-02,  8.6719e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.8799e-02, -2.9541e-02,  6.6528e-03,  ..., -5.1117e-04,\n",
       "            1.3855e-02, -3.3936e-02],\n",
       "          [-9.3750e-02,  2.1289e-01, -4.0039e-01,  ...,  4.1992e-02,\n",
       "            2.9297e-01,  3.1055e-01],\n",
       "          [-5.5176e-02,  5.6152e-02, -5.6885e-02,  ..., -2.6489e-02,\n",
       "            3.1250e-01, -1.1426e-01],\n",
       "          ...,\n",
       "          [ 5.3906e-01, -3.7891e-01,  1.2305e-01,  ..., -1.7773e-01,\n",
       "           -7.8613e-02, -2.4023e-01],\n",
       "          [ 5.2344e-01,  3.1738e-02, -3.4570e-01,  ...,  2.5757e-02,\n",
       "            4.0039e-02,  7.3438e-01],\n",
       "          [ 2.7734e-01, -5.1172e-01, -1.3672e-01,  ..., -3.1641e-01,\n",
       "            7.7148e-02,  1.5137e-01]],\n",
       "\n",
       "         [[-6.0120e-03, -3.4180e-03,  1.4267e-03,  ..., -1.0254e-02,\n",
       "           -4.3640e-03, -9.7046e-03],\n",
       "          [ 8.5449e-02,  5.0781e-02,  1.1230e-01,  ...,  4.6875e-01,\n",
       "            5.2490e-02, -5.3223e-02],\n",
       "          [ 2.5977e-01,  1.3672e-01,  1.5430e-01,  ..., -5.4688e-02,\n",
       "           -1.4258e-01, -2.4219e-01],\n",
       "          ...,\n",
       "          [-3.0273e-01,  2.2461e-02, -4.0039e-02,  ..., -1.3245e-02,\n",
       "           -6.1768e-02,  4.4336e-01],\n",
       "          [ 1.5625e-01,  7.2656e-01, -1.3672e-01,  ...,  2.0020e-01,\n",
       "           -6.5918e-02,  4.9414e-01],\n",
       "          [-7.3828e-01,  4.5508e-01,  5.9375e-01,  ...,  1.0010e-01,\n",
       "            9.6875e-01,  6.8750e-01]],\n",
       "\n",
       "         [[-4.0283e-02, -1.2207e-02,  2.0508e-02,  ...,  1.3611e-02,\n",
       "           -1.4038e-03, -6.3171e-03],\n",
       "          [ 3.4180e-01,  6.2500e-02,  3.6523e-01,  ...,  7.2754e-02,\n",
       "           -2.8516e-01, -1.8750e-01],\n",
       "          [ 1.2891e-01,  2.0996e-01,  1.4648e-01,  ..., -1.7090e-01,\n",
       "            3.3447e-02, -1.5430e-01],\n",
       "          ...,\n",
       "          [-2.1680e-01, -4.0820e-01,  1.3770e-01,  ..., -8.5449e-02,\n",
       "           -1.6699e-01, -1.7090e-01],\n",
       "          [-1.8359e-01, -5.6250e-01, -2.8906e-01,  ..., -2.8320e-01,\n",
       "           -2.6172e-01,  1.3379e-01],\n",
       "          [-6.7969e-01, -1.1133e-01,  4.7852e-01,  ..., -5.1953e-01,\n",
       "           -1.1621e-01, -7.2266e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 4.5898e-02,  4.6875e-02, -1.4709e-02,  ...,  1.8799e-02,\n",
       "            1.0938e-01, -1.2500e-01],\n",
       "          [ 1.8066e-01,  1.6992e-01,  5.5664e-02,  ..., -7.5000e-01,\n",
       "            5.3125e-01, -6.3672e-01],\n",
       "          [-1.7578e-01,  2.1680e-01, -2.3633e-01,  ..., -7.2656e-01,\n",
       "            3.7305e-01, -1.0469e+00],\n",
       "          ...,\n",
       "          [ 1.7969e-01,  1.0742e-01,  2.5586e-01,  ..., -3.3984e-01,\n",
       "            1.3984e+00,  1.0469e+00],\n",
       "          [-9.1797e-02, -3.2031e-01, -8.4766e-01,  ...,  1.8066e-01,\n",
       "            3.9453e-01,  1.6953e+00],\n",
       "          [ 1.9727e-01,  2.0996e-01, -4.6875e-01,  ..., -1.5156e+00,\n",
       "            1.0547e+00, -1.4258e-01]],\n",
       "\n",
       "         [[ 1.9653e-02, -1.3428e-03,  5.5847e-03,  ...,  6.3281e-01,\n",
       "           -3.2031e-01,  4.9219e-01],\n",
       "          [-5.1562e-01,  1.9141e-01, -1.7676e-01,  ...,  1.3203e+00,\n",
       "            8.2031e-02, -5.5078e-01],\n",
       "          [ 1.8750e-01,  4.0430e-01,  5.3516e-01,  ...,  1.0859e+00,\n",
       "           -1.2422e+00, -2.4805e-01],\n",
       "          ...,\n",
       "          [ 2.5391e-01,  9.3750e-02,  2.5391e-01,  ..., -3.5352e-01,\n",
       "           -2.6875e+00, -1.7344e+00],\n",
       "          [ 3.3789e-01,  2.2949e-01, -1.2256e-01,  ..., -1.4609e+00,\n",
       "           -1.7422e+00, -1.2891e+00],\n",
       "          [ 8.8379e-02, -9.7656e-02, -1.8164e-01,  ..., -3.3125e+00,\n",
       "           -1.0938e+00, -1.5312e+00]],\n",
       "\n",
       "         [[ 1.1902e-02,  4.4441e-04, -5.2795e-03,  ...,  7.3730e-02,\n",
       "            8.7109e-01, -5.5078e-01],\n",
       "          [-4.3945e-02, -1.9922e-01, -1.0791e-01,  ...,  4.6094e-01,\n",
       "           -4.6875e+00,  3.2656e+00],\n",
       "          [-7.1094e-01,  2.3242e-01, -2.1606e-02,  ...,  1.2812e+00,\n",
       "           -5.6250e+00,  1.7344e+00],\n",
       "          ...,\n",
       "          [ 2.6562e-01, -5.0000e-01,  3.9551e-02,  ..., -1.1035e-01,\n",
       "           -4.1250e+00,  1.1328e+00],\n",
       "          [ 3.5742e-01, -9.8633e-02, -7.8125e-02,  ..., -7.8125e-02,\n",
       "           -2.7188e+00, -9.5703e-01],\n",
       "          [-1.2305e-01,  3.4180e-01,  9.2773e-02,  ...,  1.5469e+00,\n",
       "           -3.1250e+00,  1.9727e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-2.1118e-02,  6.0425e-03,  7.9346e-03,  ...,  4.9805e-01,\n",
       "            1.1875e+00,  5.0391e-01],\n",
       "          [ 5.8984e-01, -3.6719e-01, -4.1016e-01,  ..., -9.6875e-01,\n",
       "           -4.1406e-01, -7.1094e-01],\n",
       "          [ 1.1816e-01, -4.1748e-02,  3.3594e-01,  ...,  2.2095e-02,\n",
       "           -2.2812e+00, -1.3516e+00],\n",
       "          ...,\n",
       "          [-7.9688e-01, -1.4941e-01,  1.8066e-02,  ...,  5.2979e-02,\n",
       "            1.3965e-01, -1.0938e+00],\n",
       "          [ 9.5215e-02, -1.9629e-01, -4.4922e-01,  ...,  2.7734e-01,\n",
       "           -1.8672e+00, -7.4609e-01],\n",
       "          [ 5.3125e-01,  1.1963e-02, -6.8359e-03,  ...,  5.0781e-01,\n",
       "           -2.5312e+00, -6.7578e-01]],\n",
       "\n",
       "         [[ 2.4292e-02, -4.5654e-02,  4.9316e-02,  ...,  2.8906e-01,\n",
       "            2.8516e-01, -2.1875e+00],\n",
       "          [-1.8359e-01,  1.1963e-01, -6.8359e-03,  ..., -4.5898e-02,\n",
       "            1.3438e+00,  2.7500e+00],\n",
       "          [ 5.8594e-01, -7.9590e-02, -3.8086e-01,  ..., -8.3203e-01,\n",
       "            1.1875e+00,  3.7188e+00],\n",
       "          ...,\n",
       "          [ 9.7656e-02,  2.2168e-01,  5.1172e-01,  ..., -7.1875e-01,\n",
       "            1.3203e+00,  2.9062e+00],\n",
       "          [-5.8594e-01,  1.9629e-01, -1.9336e-01,  ..., -5.5078e-01,\n",
       "           -8.8672e-01,  3.0156e+00],\n",
       "          [ 2.1387e-01, -1.3477e-01,  4.8828e-02,  ..., -7.7344e-01,\n",
       "           -7.6953e-01,  2.7188e+00]],\n",
       "\n",
       "         [[-3.3203e-02,  2.3071e-02, -5.4321e-03,  ...,  3.2031e-01,\n",
       "           -2.1582e-01, -8.4766e-01],\n",
       "          [ 2.4902e-01, -4.1016e-02,  3.5547e-01,  ..., -1.3125e+00,\n",
       "            6.4844e-01,  1.2578e+00],\n",
       "          [-6.3672e-01,  5.2344e-01,  5.9814e-02,  ..., -2.1582e-01,\n",
       "           -3.3203e-01,  2.0156e+00],\n",
       "          ...,\n",
       "          [-9.2773e-02,  1.7383e-01, -7.4219e-02,  ..., -1.9141e-01,\n",
       "            1.0000e+00,  8.7109e-01],\n",
       "          [ 2.3145e-01,  3.3203e-01, -1.8750e-01,  ..., -1.3359e+00,\n",
       "            1.2500e+00,  1.7969e+00],\n",
       "          [ 1.7285e-01, -1.0234e+00, -4.9023e-01,  ...,  1.3770e-01,\n",
       "            3.3281e+00,  2.8125e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-1.7212e-02, -5.1880e-03,  1.0437e-02,  ...,  1.0681e-02,\n",
       "            2.6245e-02,  3.0884e-02],\n",
       "          [-3.0469e-01,  3.5742e-01,  8.3008e-02,  ..., -7.4219e-02,\n",
       "           -1.7871e-01, -1.5820e-01],\n",
       "          [-4.5117e-01,  4.6680e-01, -5.9082e-02,  ...,  1.4062e-01,\n",
       "           -5.7422e-01, -5.8594e-01],\n",
       "          ...,\n",
       "          [-1.5625e-01,  9.7656e-02,  3.0664e-01,  ..., -1.0254e-01,\n",
       "            2.2363e-01, -3.2031e-01],\n",
       "          [-3.5156e-01, -3.3398e-01,  1.9727e-01,  ...,  6.2500e-02,\n",
       "           -3.9844e-01,  1.8750e-01],\n",
       "          [ 3.9453e-01,  3.0078e-01,  4.4531e-01,  ..., -6.9922e-01,\n",
       "           -6.9141e-01,  1.1670e-01]],\n",
       "\n",
       "         [[ 5.8899e-03,  3.0640e-02, -1.3916e-02,  ...,  1.6016e-01,\n",
       "            1.0864e-02,  2.4567e-03],\n",
       "          [-1.5625e-01, -2.4414e-01,  2.3438e-01,  ..., -5.4688e-01,\n",
       "           -1.0498e-01, -5.7031e-01],\n",
       "          [-2.0898e-01,  1.3281e-01, -1.0791e-01,  ..., -1.9629e-01,\n",
       "            1.6113e-01,  5.9570e-02],\n",
       "          ...,\n",
       "          [-2.9492e-01, -2.2559e-01,  1.9531e-01,  ..., -3.3008e-01,\n",
       "           -9.0625e-01, -1.6211e-01],\n",
       "          [-1.3770e-01, -1.3379e-01,  3.4570e-01,  ..., -2.7539e-01,\n",
       "           -6.6797e-01,  1.1377e-01],\n",
       "          [ 1.9897e-02,  5.0391e-01,  4.7656e-01,  ..., -6.1279e-02,\n",
       "           -7.6953e-01, -1.3379e-01]],\n",
       "\n",
       "         [[-1.1673e-03, -1.1780e-02,  2.0142e-03,  ..., -8.6670e-03,\n",
       "           -3.3722e-03, -2.1515e-03],\n",
       "          [-3.6719e-01,  2.9492e-01, -3.5352e-01,  ...,  3.3594e-01,\n",
       "           -4.2578e-01, -1.2031e+00],\n",
       "          [-1.2207e-01,  4.1406e-01, -1.0498e-01,  ...,  7.8125e-01,\n",
       "           -1.2422e+00, -4.8047e-01],\n",
       "          ...,\n",
       "          [-4.1406e-01,  6.0547e-01, -2.2070e-01,  ...,  2.8809e-02,\n",
       "           -1.4355e-01, -7.5000e-01],\n",
       "          [ 2.2949e-01,  2.4805e-01,  1.1621e-01,  ...,  7.1875e-01,\n",
       "           -8.0078e-01, -2.2559e-01],\n",
       "          [-6.7578e-01,  2.6758e-01, -4.0234e-01,  ...,  1.7383e-01,\n",
       "           -1.0156e+00,  3.4332e-03]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-3.3379e-04,  7.7820e-03,  3.3264e-03,  ..., -3.1433e-03,\n",
       "            1.9165e-02, -3.2349e-03],\n",
       "          [-2.3242e-01, -4.0234e-01, -2.7344e-01,  ..., -4.4336e-01,\n",
       "            2.9688e-01,  6.2891e-01],\n",
       "          [ 3.9844e-01,  2.1680e-01,  9.4531e-01,  ..., -2.5977e-01,\n",
       "            1.6602e-01,  1.7773e-01],\n",
       "          ...,\n",
       "          [ 2.7344e-01, -8.9722e-03,  2.8711e-01,  ..., -9.8438e-01,\n",
       "           -1.0234e+00, -5.4297e-01],\n",
       "          [ 1.8555e-01,  6.1719e-01,  5.0049e-02,  ..., -5.3516e-01,\n",
       "           -7.7734e-01, -9.2578e-01],\n",
       "          [ 9.0942e-03,  4.5117e-01,  4.0234e-01,  ..., -1.1719e-01,\n",
       "           -8.3984e-01, -2.1094e-01]],\n",
       "\n",
       "         [[-1.5182e-03,  2.5787e-03,  4.1504e-03,  ...,  7.5531e-04,\n",
       "            3.5645e-02, -2.8442e-02],\n",
       "          [ 6.3477e-02,  3.1641e-01, -2.1973e-02,  ...,  1.7969e-01,\n",
       "           -3.3203e-01, -1.7871e-01],\n",
       "          [-1.0254e-01, -2.2363e-01, -3.4961e-01,  ...,  2.0898e-01,\n",
       "            1.2012e-01,  1.0254e-01],\n",
       "          ...,\n",
       "          [-2.8711e-01, -4.3359e-01, -3.1641e-01,  ...,  1.4062e-01,\n",
       "            2.2266e-01,  4.0430e-01],\n",
       "          [-7.9956e-03,  3.4424e-02, -4.6094e-01,  ...,  2.3926e-01,\n",
       "            3.1250e-01, -9.0820e-02],\n",
       "          [ 6.6016e-01, -2.9883e-01, -4.9414e-01,  ..., -2.0117e-01,\n",
       "            4.5166e-03, -1.5918e-01]],\n",
       "\n",
       "         [[ 2.8442e-02, -4.3701e-02,  2.7344e-02,  ..., -4.3213e-02,\n",
       "           -2.1118e-02,  2.2583e-02],\n",
       "          [ 1.3477e-01,  5.4297e-01,  3.3789e-01,  ...,  7.2266e-02,\n",
       "           -2.8711e-01,  1.9727e-01],\n",
       "          [-9.3262e-02,  4.5898e-01,  6.4062e-01,  ...,  1.2207e-01,\n",
       "            1.0010e-02,  1.8188e-02],\n",
       "          ...,\n",
       "          [-4.2969e-01, -1.9531e-01,  1.2589e-03,  ...,  5.1953e-01,\n",
       "            2.4512e-01, -1.3379e-01],\n",
       "          [-5.7031e-01, -9.7168e-02,  1.2695e-01,  ...,  8.5938e-02,\n",
       "            3.1641e-01, -3.1836e-01],\n",
       "          [-2.9785e-02, -3.3594e-01,  1.2695e-01,  ..., -1.7480e-01,\n",
       "            4.5312e-01, -3.0078e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 2.7954e-02, -2.2583e-02, -2.5787e-03,  ..., -2.7539e-01,\n",
       "           -3.2959e-02,  3.2617e-01],\n",
       "          [-1.3281e-01, -4.3555e-01, -2.1094e-01,  ..., -2.8320e-01,\n",
       "           -3.3594e-01, -2.7148e-01],\n",
       "          [-3.6523e-01,  1.9629e-01, -3.9258e-01,  ...,  4.9219e-01,\n",
       "           -4.1992e-01, -1.1094e+00],\n",
       "          ...,\n",
       "          [ 2.8516e-01, -7.0801e-02,  2.0801e-01,  ...,  8.0859e-01,\n",
       "            1.4219e+00, -3.6328e-01],\n",
       "          [ 1.6797e-01,  5.2344e-01,  1.4258e-01,  ...,  8.5547e-01,\n",
       "           -7.1289e-02, -2.5000e-01],\n",
       "          [ 2.8125e-01, -1.3184e-01, -3.4766e-01,  ..., -2.7222e-02,\n",
       "            1.3984e+00, -1.4453e+00]],\n",
       "\n",
       "         [[ 1.4404e-02, -1.0132e-02, -5.0659e-03,  ..., -8.9844e-02,\n",
       "           -1.9897e-02, -1.0010e-01],\n",
       "          [-1.2695e-01,  1.7090e-01,  2.1875e-01,  ...,  3.3447e-02,\n",
       "           -1.0938e+00,  9.0625e-01],\n",
       "          [-3.3203e-02,  1.5625e-02,  2.2168e-01,  ..., -6.7578e-01,\n",
       "            1.2578e+00,  1.5312e+00],\n",
       "          ...,\n",
       "          [-2.4805e-01, -1.5430e-01, -2.9883e-01,  ..., -1.2812e+00,\n",
       "            3.0859e-01,  1.3203e+00],\n",
       "          [-2.9688e-01,  9.5312e-01, -4.5312e-01,  ..., -5.8984e-01,\n",
       "            1.4688e+00,  2.0156e+00],\n",
       "          [ 2.0898e-01,  1.0156e-01, -5.8594e-01,  ...,  3.9062e-01,\n",
       "            7.8906e-01,  1.9922e+00]],\n",
       "\n",
       "         [[ 9.8267e-03, -1.1047e-02, -5.5542e-03,  ..., -6.4941e-02,\n",
       "            5.7983e-03, -9.8828e-01],\n",
       "          [-4.9609e-01, -4.8096e-02, -2.3535e-01,  ..., -2.7930e-01,\n",
       "           -1.6406e+00, -1.4375e+00],\n",
       "          [ 7.8125e-03,  5.1562e-01,  1.6211e-01,  ...,  1.6797e+00,\n",
       "            3.7305e-01, -5.6250e-01],\n",
       "          ...,\n",
       "          [ 1.8066e-01,  9.5215e-02, -6.1719e-01,  ...,  4.3555e-01,\n",
       "           -6.2891e-01, -4.4727e-01],\n",
       "          [ 3.2031e-01, -1.6602e-02,  7.8125e-01,  ...,  6.7188e-01,\n",
       "            7.1484e-01,  3.0625e+00],\n",
       "          [-6.4453e-01,  1.1797e+00,  3.3789e-01,  ...,  3.0625e+00,\n",
       "           -3.9453e-01,  3.7344e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-3.1738e-02, -1.7944e-02,  1.9531e-02,  ..., -1.9336e-01,\n",
       "           -2.3633e-01, -7.9297e-01],\n",
       "          [-8.8281e-01, -6.1719e-01, -3.7109e-01,  ..., -1.0986e-01,\n",
       "            1.0391e+00,  1.3594e+00],\n",
       "          [ 3.9648e-01, -6.4453e-01, -4.0430e-01,  ...,  1.0352e-01,\n",
       "            7.7344e-01,  1.2969e+00],\n",
       "          ...,\n",
       "          [ 3.1250e-01, -5.1172e-01, -3.8281e-01,  ...,  8.0859e-01,\n",
       "            1.5527e-01,  6.8750e-01],\n",
       "          [ 3.6377e-02, -8.6719e-01, -5.0781e-01,  ...,  9.5703e-01,\n",
       "            3.7695e-01,  1.4922e+00],\n",
       "          [ 2.6953e-01, -9.9219e-01,  3.8867e-01,  ...,  4.0234e-01,\n",
       "            1.4297e+00,  1.3359e+00]],\n",
       "\n",
       "         [[-1.8677e-02,  1.0071e-02, -2.6978e-02,  ...,  2.6562e-01,\n",
       "           -8.6719e-01, -5.4297e-01],\n",
       "          [ 5.4688e-02,  2.4805e-01,  4.0625e-01,  ..., -2.2969e+00,\n",
       "           -9.3359e-01, -2.3145e-01],\n",
       "          [-1.6992e-01,  6.5625e-01, -1.7188e-01,  ..., -2.1562e+00,\n",
       "            1.5234e+00,  7.8516e-01],\n",
       "          ...,\n",
       "          [ 3.1836e-01, -4.9414e-01, -2.6758e-01,  ..., -1.6875e+00,\n",
       "           -1.9727e-01, -1.2734e+00],\n",
       "          [ 1.2188e+00,  1.2207e-02,  6.7188e-01,  ..., -1.5703e+00,\n",
       "            2.0469e+00,  1.0781e+00],\n",
       "          [-2.7539e-01, -1.6016e-01, -1.2598e-01,  ..., -2.2188e+00,\n",
       "            4.5508e-01,  2.6758e-01]],\n",
       "\n",
       "         [[-1.0742e-02, -2.1240e-02, -3.3188e-04,  ..., -1.2329e-02,\n",
       "            1.7188e-01, -3.3398e-01],\n",
       "          [-4.3945e-02, -3.8477e-01,  6.7578e-01,  ..., -1.6875e+00,\n",
       "           -9.2578e-01, -2.5625e+00],\n",
       "          [-2.3438e-02, -2.3633e-01,  4.2773e-01,  ..., -7.3438e-01,\n",
       "           -1.4688e+00, -4.0312e+00],\n",
       "          ...,\n",
       "          [ 4.1406e-01,  6.5918e-03,  3.6133e-01,  ...,  2.8906e-01,\n",
       "           -9.5312e-01, -3.2969e+00],\n",
       "          [ 6.2500e-01, -1.6211e-01,  3.3447e-02,  ..., -5.6250e-01,\n",
       "           -7.7734e-01, -8.7109e-01],\n",
       "          [-4.6484e-01, -1.4062e-01,  5.0781e-01,  ...,  1.5820e-01,\n",
       "           -5.1562e-01, -6.0938e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-2.3071e-02,  4.9072e-02,  5.0293e-02,  ...,  7.1411e-03,\n",
       "           -4.2725e-03, -1.6968e-02],\n",
       "          [-1.5381e-02, -3.4570e-01, -3.1055e-01,  ...,  4.2480e-02,\n",
       "           -4.4141e-01,  2.0020e-01],\n",
       "          [-5.5469e-01, -3.0859e-01,  1.6895e-01,  ..., -9.6094e-01,\n",
       "           -3.2031e-01,  1.7969e-01],\n",
       "          ...,\n",
       "          [-1.6504e-01, -4.0039e-01,  2.7539e-01,  ..., -3.7500e-01,\n",
       "           -3.4668e-02,  3.5889e-02],\n",
       "          [-3.6328e-01, -5.0293e-02,  3.6377e-02,  ..., -8.7891e-02,\n",
       "            2.3828e-01,  2.2852e-01],\n",
       "          [ 6.2500e-02,  2.1240e-02,  4.3945e-02,  ..., -1.3770e-01,\n",
       "           -1.1426e-01, -2.4121e-01]],\n",
       "\n",
       "         [[-1.7395e-03, -5.9204e-03, -1.4648e-02,  ...,  2.9541e-02,\n",
       "           -7.6675e-04, -3.5477e-04],\n",
       "          [ 1.4551e-01,  2.2461e-02, -4.9805e-02,  ..., -6.1328e-01,\n",
       "            1.8750e-01, -2.3340e-01],\n",
       "          [-3.9551e-02,  2.6172e-01,  2.8516e-01,  ..., -3.6523e-01,\n",
       "            7.0312e-01,  2.9785e-02],\n",
       "          ...,\n",
       "          [ 9.6191e-02, -1.0010e-01,  8.8867e-02,  ...,  2.4658e-02,\n",
       "            1.3379e-01, -3.8867e-01],\n",
       "          [ 4.0820e-01, -1.0132e-02, -2.1289e-01,  ..., -3.5547e-01,\n",
       "            1.2891e-01,  1.1182e-01],\n",
       "          [ 2.0898e-01, -5.5469e-01, -8.3008e-02,  ..., -9.8438e-01,\n",
       "            6.2109e-01, -1.1475e-01]],\n",
       "\n",
       "         [[-2.6611e-02,  5.6152e-03,  1.2512e-02,  ..., -1.5625e-02,\n",
       "            1.7090e-02,  1.5625e-02],\n",
       "          [ 1.1719e-01, -3.2422e-01, -1.6016e-01,  ...,  1.7871e-01,\n",
       "           -2.3340e-01,  2.6562e-01],\n",
       "          [ 2.2559e-01, -3.7500e-01, -2.0215e-01,  ..., -6.5918e-02,\n",
       "           -3.3984e-01,  4.7070e-01],\n",
       "          ...,\n",
       "          [ 2.4121e-01, -1.6504e-01,  4.1992e-01,  ..., -4.1797e-01,\n",
       "           -3.9258e-01, -1.1377e-01],\n",
       "          [ 3.1128e-02,  6.6406e-02,  3.1836e-01,  ..., -3.1836e-01,\n",
       "           -2.1875e-01, -4.3164e-01],\n",
       "          [-2.2461e-01,  6.6797e-01, -4.6997e-03,  ..., -2.0703e-01,\n",
       "           -6.9336e-02,  2.8320e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 3.6377e-02,  2.3438e-02,  1.3428e-02,  ...,  7.6294e-03,\n",
       "            2.2461e-02, -5.2002e-02],\n",
       "          [-8.8379e-02,  7.3047e-01,  1.8066e-01,  ..., -1.1865e-01,\n",
       "            9.9121e-02,  2.3438e-01],\n",
       "          [-5.2344e-01, -2.4316e-01, -4.2578e-01,  ..., -7.2754e-02,\n",
       "            2.2754e-01,  6.6797e-01],\n",
       "          ...,\n",
       "          [ 1.5039e-01,  5.9326e-02, -4.2969e-01,  ...,  2.5391e-01,\n",
       "           -1.5918e-01,  4.3359e-01],\n",
       "          [ 4.2725e-02,  1.6211e-01, -3.5547e-01,  ..., -1.5234e-01,\n",
       "           -3.4424e-02,  1.4941e-01],\n",
       "          [ 2.9688e-01, -4.9316e-02, -4.6484e-01,  ..., -1.9336e-01,\n",
       "           -3.9648e-01, -5.1562e-01]],\n",
       "\n",
       "         [[ 1.9409e-02, -2.1973e-02, -9.1553e-03,  ...,  3.6011e-03,\n",
       "            2.2095e-02,  4.5166e-03],\n",
       "          [ 3.2812e-01, -1.4648e-01,  5.3906e-01,  ..., -3.7305e-01,\n",
       "            4.0234e-01,  5.0781e-01],\n",
       "          [ 4.0820e-01, -4.6680e-01,  3.4180e-01,  ...,  1.7383e-01,\n",
       "            4.1406e-01,  2.9492e-01],\n",
       "          ...,\n",
       "          [ 4.5508e-01, -5.0781e-01,  4.7852e-01,  ...,  3.5352e-01,\n",
       "            2.7539e-01, -1.1035e-01],\n",
       "          [ 2.1191e-01, -5.4297e-01,  4.2773e-01,  ...,  6.6797e-01,\n",
       "            1.2305e-01,  7.5684e-02],\n",
       "          [ 2.6611e-02, -2.1680e-01,  1.7480e-01,  ...,  1.5527e-01,\n",
       "           -2.2070e-01,  9.2285e-02]],\n",
       "\n",
       "         [[-4.3945e-02, -1.4465e-02, -1.4954e-02,  ..., -1.3611e-02,\n",
       "            1.4954e-02,  6.3782e-03],\n",
       "          [-8.3008e-02, -1.7090e-01,  2.4707e-01,  ...,  2.8125e-01,\n",
       "           -1.7285e-01,  1.9824e-01],\n",
       "          [ 3.5352e-01, -8.2422e-01,  6.4453e-01,  ...,  5.8594e-02,\n",
       "           -2.8076e-02,  1.0681e-02],\n",
       "          ...,\n",
       "          [ 2.8125e-01, -1.8457e-01,  6.0156e-01,  ...,  6.6406e-01,\n",
       "           -3.5547e-01, -3.9062e-01],\n",
       "          [-7.4219e-02, -1.1353e-02,  9.2188e-01,  ...,  3.7891e-01,\n",
       "           -6.4062e-01, -8.4766e-01],\n",
       "          [-3.5938e-01, -8.6914e-02,  6.5234e-01,  ...,  5.1562e-01,\n",
       "            6.1328e-01, -3.0469e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-3.8910e-03,  1.9409e-02,  1.7456e-02,  ..., -2.2852e-01,\n",
       "            1.4746e-01, -1.0840e-01],\n",
       "          [ 8.9844e-01,  3.2227e-01,  5.0000e-01,  ..., -8.9453e-01,\n",
       "           -5.9375e-01,  4.4141e-01],\n",
       "          [-1.7578e-02, -1.8750e-01,  5.1270e-02,  ..., -3.8086e-01,\n",
       "           -1.7090e-01,  6.7188e-01],\n",
       "          ...,\n",
       "          [ 4.3359e-01,  5.3906e-01, -4.1016e-01,  ..., -4.8438e-01,\n",
       "            1.1641e+00, -1.3281e-01],\n",
       "          [ 1.1016e+00, -1.6602e-02,  5.3516e-01,  ...,  2.5977e-01,\n",
       "           -6.1279e-02,  1.2578e+00],\n",
       "          [ 7.8125e-01, -9.1406e-01, -1.5918e-01,  ..., -3.0469e-01,\n",
       "           -5.6885e-02,  6.1719e-01]],\n",
       "\n",
       "         [[-2.0264e-02,  5.3406e-03,  6.2256e-03,  ..., -2.2031e+00,\n",
       "            2.5391e-01,  4.8242e-01],\n",
       "          [-2.3340e-01,  4.7461e-01, -3.2031e-01,  ...,  2.1719e+00,\n",
       "           -2.8125e-01, -2.5781e+00],\n",
       "          [ 3.2227e-02,  4.2383e-01,  4.5312e-01,  ...,  3.7969e+00,\n",
       "           -1.4375e+00, -3.4219e+00],\n",
       "          ...,\n",
       "          [-4.8828e-04, -2.6978e-02,  4.9609e-01,  ...,  2.8594e+00,\n",
       "           -1.9141e+00, -2.5625e+00],\n",
       "          [-4.5117e-01, -2.8125e-01,  1.2695e-02,  ...,  4.8750e+00,\n",
       "           -2.5781e+00, -3.1406e+00],\n",
       "          [-2.6953e-01, -2.6953e-01, -3.4570e-01,  ...,  3.2969e+00,\n",
       "           -2.3281e+00, -2.4219e+00]],\n",
       "\n",
       "         [[-5.2795e-03, -3.9673e-03,  3.2959e-02,  ...,  6.6406e-02,\n",
       "           -5.2002e-02, -1.9434e-01],\n",
       "          [-4.7852e-02, -7.1875e-01,  8.7891e-02,  ...,  1.3281e-01,\n",
       "           -2.4531e+00,  9.4922e-01],\n",
       "          [ 2.3730e-01, -1.1719e-01, -2.9053e-02,  ..., -2.0898e-01,\n",
       "           -2.2969e+00,  9.5367e-04],\n",
       "          ...,\n",
       "          [ 3.1641e-01,  4.2773e-01, -6.9141e-01,  ...,  1.0391e+00,\n",
       "            1.3359e+00, -1.1406e+00],\n",
       "          [ 6.6797e-01,  3.7500e-01,  6.7188e-01,  ..., -4.4922e-01,\n",
       "            7.1094e-01, -1.6719e+00],\n",
       "          [ 2.1387e-01, -5.5859e-01,  2.9785e-02,  ..., -1.1875e+00,\n",
       "           -1.6504e-01, -6.5234e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.5503e-02, -8.9722e-03, -1.7578e-02,  ..., -1.8281e+00,\n",
       "            2.2031e+00, -6.2109e-01],\n",
       "          [-2.7344e-02, -1.8652e-01,  4.4922e-02,  ...,  1.7344e+00,\n",
       "           -9.6094e-01,  2.4062e+00],\n",
       "          [ 3.1494e-02,  5.7617e-02,  3.0029e-02,  ...,  3.8438e+00,\n",
       "           -3.8125e+00,  2.4688e+00],\n",
       "          ...,\n",
       "          [-4.4141e-01,  3.2227e-02, -6.1035e-02,  ...,  3.3906e+00,\n",
       "           -1.8047e+00,  1.3828e+00],\n",
       "          [ 2.1387e-01,  4.0625e-01,  1.0254e-02,  ...,  5.0625e+00,\n",
       "           -3.6250e+00,  1.0312e+00],\n",
       "          [-2.2656e-01,  2.2852e-01, -4.6484e-01,  ...,  5.0312e+00,\n",
       "           -1.1172e+00, -6.4062e-01]],\n",
       "\n",
       "         [[-2.4658e-02,  1.1475e-02, -6.4697e-03,  ..., -1.4062e-01,\n",
       "            1.7285e-01, -7.4219e-01],\n",
       "          [ 1.0645e-01, -4.4727e-01, -2.1680e-01,  ..., -3.5156e-02,\n",
       "           -1.3672e-01,  2.0469e+00],\n",
       "          [-2.5977e-01, -1.0352e-01, -5.3467e-02,  ...,  2.5000e-01,\n",
       "           -6.5625e-01,  2.7344e+00],\n",
       "          ...,\n",
       "          [-2.9297e-03,  4.6875e-01, -2.3535e-01,  ...,  6.9141e-01,\n",
       "            1.0156e+00,  1.3203e+00],\n",
       "          [-2.2266e-01,  6.5234e-01, -3.4180e-02,  ...,  6.6406e-01,\n",
       "           -6.9824e-02,  1.2656e+00],\n",
       "          [-5.3906e-01, -3.5547e-01,  9.4727e-02,  ..., -2.2070e-01,\n",
       "           -1.3438e+00, -1.8799e-02]],\n",
       "\n",
       "         [[ 1.0925e-02,  1.1841e-02,  7.8125e-03,  ..., -2.7588e-02,\n",
       "           -1.4844e-01,  5.4932e-02],\n",
       "          [-2.6367e-01, -6.9336e-02, -4.2969e-01,  ...,  4.5117e-01,\n",
       "            1.3438e+00, -1.6016e+00],\n",
       "          [ 6.8750e-01,  1.2793e-01,  6.7188e-01,  ...,  2.1118e-02,\n",
       "            9.2969e-01,  1.7969e-01],\n",
       "          ...,\n",
       "          [-1.2891e+00,  1.2344e+00,  7.5781e-01,  ...,  4.0430e-01,\n",
       "            5.8203e-01,  7.1094e-01],\n",
       "          [-8.9844e-01,  8.5938e-02,  2.0801e-01,  ..., -8.3594e-01,\n",
       "            5.3906e-01,  1.0156e+00],\n",
       "          [ 1.9727e-01, -4.2578e-01,  4.0234e-01,  ..., -1.9062e+00,\n",
       "            3.8867e-01, -2.0996e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 4.4434e-02,  9.8877e-03,  1.0803e-02,  ...,  9.8419e-04,\n",
       "           -2.1515e-03, -6.1523e-02],\n",
       "          [ 2.6367e-01, -1.7480e-01,  1.2158e-01,  ..., -4.1406e-01,\n",
       "            2.5586e-01,  2.6562e-01],\n",
       "          [ 3.5547e-01,  5.7031e-01, -3.0151e-02,  ...,  4.8828e-02,\n",
       "           -4.9072e-02,  5.6763e-03],\n",
       "          ...,\n",
       "          [ 6.5234e-01,  1.1719e-01, -1.8066e-01,  ..., -3.0396e-02,\n",
       "           -3.8574e-02,  6.6406e-01],\n",
       "          [ 3.8477e-01, -6.8359e-02,  1.5820e-01,  ..., -2.8516e-01,\n",
       "           -3.2227e-01,  6.4062e-01],\n",
       "          [ 3.2031e-01,  9.3994e-03,  7.8125e-03,  ...,  2.5391e-01,\n",
       "            7.8906e-01, -2.5000e-01]],\n",
       "\n",
       "         [[ 2.7161e-03, -2.6245e-03,  7.4768e-03,  ..., -6.3705e-04,\n",
       "            3.2806e-03,  3.7842e-03],\n",
       "          [ 1.4954e-02, -1.3086e-01, -6.4453e-02,  ..., -4.3945e-02,\n",
       "           -1.0840e-01,  5.1953e-01],\n",
       "          [ 1.9336e-01, -6.9531e-01, -6.7969e-01,  ...,  9.1309e-02,\n",
       "           -4.4922e-01,  3.2422e-01],\n",
       "          ...,\n",
       "          [-2.7222e-02,  9.9609e-01,  1.1084e-01,  ..., -3.6523e-01,\n",
       "           -5.6152e-02, -2.9688e-01],\n",
       "          [ 8.3496e-02,  4.3359e-01,  3.1055e-01,  ..., -1.4746e-01,\n",
       "            4.2969e-01, -2.1094e-01],\n",
       "          [ 1.3281e-01,  8.4766e-01,  4.0625e-01,  ..., -3.1445e-01,\n",
       "           -8.1250e-01, -6.3281e-01]],\n",
       "\n",
       "         [[ 7.8735e-03, -4.8218e-03, -2.4719e-03,  ...,  1.5503e-02,\n",
       "            5.1575e-03, -1.1215e-03],\n",
       "          [-8.0566e-02, -1.2793e-01,  2.0801e-01,  ...,  4.2578e-01,\n",
       "            5.2002e-02,  1.0156e-01],\n",
       "          [-1.2451e-01,  2.6953e-01,  4.3945e-01,  ...,  6.2891e-01,\n",
       "           -2.1729e-02,  1.7773e-01],\n",
       "          ...,\n",
       "          [ 2.2949e-01, -3.3203e-02, -5.1562e-01,  ...,  9.6680e-02,\n",
       "            2.6758e-01,  3.2422e-01],\n",
       "          [-4.6387e-02, -1.6113e-01, -3.5352e-01,  ..., -7.5195e-02,\n",
       "           -1.6113e-01, -2.8516e-01],\n",
       "          [-2.0312e-01,  2.6172e-01, -1.3477e-01,  ...,  6.6016e-01,\n",
       "            2.6172e-01,  6.9141e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.7212e-02, -2.4414e-02, -8.5449e-03,  ...,  7.2937e-03,\n",
       "           -3.1250e-02,  2.9907e-02],\n",
       "          [ 3.7598e-02,  2.1582e-01,  3.6719e-01,  ..., -2.1875e-01,\n",
       "            3.5938e-01,  1.1133e-01],\n",
       "          [ 1.5527e-01,  1.1621e-01, -3.1055e-01,  ..., -2.7148e-01,\n",
       "           -2.9907e-03,  2.6758e-01],\n",
       "          ...,\n",
       "          [ 2.0117e-01,  3.4912e-02,  1.3184e-01,  ..., -2.1851e-02,\n",
       "           -1.5430e-01,  9.0625e-01],\n",
       "          [ 4.3945e-01, -1.5039e-01,  1.3184e-01,  ..., -1.5625e-01,\n",
       "           -4.1602e-01,  6.6016e-01],\n",
       "          [ 2.5781e-01,  6.5918e-02,  2.1387e-01,  ..., -3.4180e-01,\n",
       "           -5.6250e-01,  9.4922e-01]],\n",
       "\n",
       "         [[-3.0823e-03, -2.5513e-02,  9.4604e-03,  ..., -2.1362e-02,\n",
       "           -5.3711e-03, -2.2461e-02],\n",
       "          [-7.6172e-01,  1.3184e-01,  3.4180e-01,  ...,  2.1191e-01,\n",
       "           -4.1211e-01,  2.8516e-01],\n",
       "          [-2.1680e-01,  3.9453e-01,  6.5234e-01,  ..., -3.8281e-01,\n",
       "            5.5469e-01, -1.1963e-01],\n",
       "          ...,\n",
       "          [ 6.9824e-02, -3.0078e-01, -1.4355e-01,  ..., -2.1680e-01,\n",
       "            1.4746e-01, -1.2109e-01],\n",
       "          [-1.5137e-01, -1.5625e-01,  1.5039e-01,  ..., -4.2578e-01,\n",
       "            4.8218e-03,  6.4453e-02],\n",
       "          [-1.1094e+00,  6.8359e-01,  3.7109e-01,  ...,  7.6172e-02,\n",
       "           -2.8711e-01, -2.2339e-02]],\n",
       "\n",
       "         [[ 3.4668e-02,  1.3809e-03, -9.1553e-03,  ...,  5.7068e-03,\n",
       "            8.9111e-03,  1.6113e-02],\n",
       "          [ 5.7068e-03,  1.4746e-01,  1.6797e-01,  ..., -3.9844e-01,\n",
       "           -2.1973e-01, -1.8311e-02],\n",
       "          [ 3.8867e-01,  4.6094e-01, -6.8750e-01,  ..., -1.2793e-01,\n",
       "           -2.3926e-01,  1.0156e-01],\n",
       "          ...,\n",
       "          [-6.3965e-02, -5.4932e-02,  1.9043e-01,  ...,  4.4434e-02,\n",
       "           -3.3008e-01, -1.8677e-02],\n",
       "          [ 5.3516e-01,  6.9531e-01, -3.4375e-01,  ..., -3.5938e-01,\n",
       "           -6.8848e-02, -2.7539e-01],\n",
       "          [-4.9219e-01,  7.9102e-02, -2.3926e-01,  ..., -1.7383e-01,\n",
       "            1.7212e-02, -3.6523e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-1.4709e-02, -5.3711e-03, -2.3193e-02,  ..., -1.2451e-01,\n",
       "           -6.5918e-02, -9.0332e-02],\n",
       "          [-2.6758e-01, -1.2598e-01, -1.3184e-02,  ...,  3.8672e-01,\n",
       "           -1.3203e+00,  3.3750e+00],\n",
       "          [ 2.0996e-02,  3.3984e-01,  2.0801e-01,  ...,  3.1055e-01,\n",
       "           -7.5781e-01,  2.4062e+00],\n",
       "          ...,\n",
       "          [ 1.3086e-01, -4.1992e-01,  3.8477e-01,  ..., -8.1250e-01,\n",
       "            9.2188e-01,  2.6719e+00],\n",
       "          [ 1.2256e-01, -8.8281e-01, -2.0312e-01,  ...,  5.6250e-01,\n",
       "            2.6953e-01,  2.6719e+00],\n",
       "          [ 1.7578e-01, -7.8125e-01, -2.0605e-01,  ...,  7.7344e-01,\n",
       "            8.7500e-01,  1.0469e+00]],\n",
       "\n",
       "         [[ 1.2329e-02, -8.1177e-03, -4.3640e-03,  ..., -7.2754e-02,\n",
       "           -4.9561e-02, -1.3965e-01],\n",
       "          [-1.3867e-01,  3.7891e-01,  2.0996e-01,  ..., -6.4062e-01,\n",
       "            2.4062e+00, -2.3281e+00],\n",
       "          [-7.1094e-01,  1.3906e+00, -5.3906e-01,  ..., -2.8320e-01,\n",
       "            2.2188e+00, -1.6875e+00],\n",
       "          ...,\n",
       "          [-3.3203e-02, -1.9434e-01, -3.7109e-02,  ...,  5.4297e-01,\n",
       "            1.5391e+00, -1.5078e+00],\n",
       "          [ 2.0898e-01,  4.5312e-01, -8.7109e-01,  ...,  2.6953e-01,\n",
       "           -4.5654e-02, -1.2422e+00],\n",
       "          [-2.5195e-01, -9.2773e-02, -1.6406e-01,  ...,  2.5391e-01,\n",
       "            1.3047e+00, -1.3672e+00]],\n",
       "\n",
       "         [[ 2.9663e-02,  5.6641e-02, -1.5488e-03,  ...,  3.9453e-01,\n",
       "           -5.6641e-01, -5.8350e-02],\n",
       "          [ 3.7891e-01,  2.1484e-01, -7.1289e-02,  ..., -1.8750e-01,\n",
       "            2.7930e-01,  2.9688e-01],\n",
       "          [-3.2422e-01,  2.3438e-01, -1.4375e+00,  ...,  4.4336e-01,\n",
       "            5.4688e-02, -2.0508e-01],\n",
       "          ...,\n",
       "          [ 4.4531e-01, -2.6172e-01, -1.8945e-01,  ..., -1.2891e-01,\n",
       "            1.7266e+00, -1.2109e+00],\n",
       "          [ 4.3750e-01,  1.1865e-01, -9.7656e-03,  ...,  7.5781e-01,\n",
       "            2.4688e+00, -8.2422e-01],\n",
       "          [ 2.3633e-01, -8.0566e-02, -7.2266e-02,  ..., -9.4238e-02,\n",
       "            1.7188e-01,  3.5156e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 4.0527e-02, -1.9287e-02, -1.7700e-02,  ...,  4.7852e-02,\n",
       "           -4.3164e-01,  1.9141e-01],\n",
       "          [-4.9023e-01, -6.3672e-01,  6.1719e-01,  ...,  1.3125e+00,\n",
       "            5.1562e-01, -8.2422e-01],\n",
       "          [ 1.9531e-03,  2.8906e-01, -3.4961e-01,  ...,  1.3750e+00,\n",
       "            4.6875e-01, -7.5781e-01],\n",
       "          ...,\n",
       "          [ 5.8203e-01, -1.1816e-01,  7.1875e-01,  ...,  7.7344e-01,\n",
       "            2.6406e+00, -1.0312e+00],\n",
       "          [ 3.9307e-02,  9.7656e-04,  3.7500e-01,  ...,  8.9844e-01,\n",
       "            1.7109e+00, -1.0596e-01],\n",
       "          [ 1.0742e-01, -3.0078e-01, -2.9785e-02,  ..., -1.9727e-01,\n",
       "            2.8594e+00,  8.3203e-01]],\n",
       "\n",
       "         [[-2.5787e-03,  2.0142e-02,  2.4048e-02,  ...,  7.2266e-01,\n",
       "           -2.2363e-01, -2.5000e+00],\n",
       "          [ 7.5391e-01, -3.2031e-01, -4.3213e-02,  ..., -2.0215e-01,\n",
       "            7.1094e-01,  7.3438e+00],\n",
       "          [ 6.7188e-01, -9.3750e-02,  1.6309e-01,  ..., -6.1328e-01,\n",
       "            2.9297e-01,  9.8125e+00],\n",
       "          ...,\n",
       "          [-3.0273e-01,  1.2891e-01,  7.7344e-01,  ..., -5.6250e-01,\n",
       "            1.4453e+00,  1.0375e+01],\n",
       "          [-2.7734e-01, -4.2773e-01, -2.2070e-01,  ..., -5.6641e-02,\n",
       "           -3.1055e-01,  1.0312e+01],\n",
       "          [ 1.9336e-01,  3.0078e-01,  4.5410e-02,  ..., -2.3594e+00,\n",
       "           -2.6953e-01,  7.8750e+00]],\n",
       "\n",
       "         [[-3.3691e-02,  2.7588e-02, -1.8188e-02,  ..., -2.9297e-01,\n",
       "           -1.6797e-01,  2.5586e-01],\n",
       "          [ 1.0352e-01,  3.2617e-01,  1.1621e-01,  ...,  2.1875e+00,\n",
       "            1.6953e+00,  1.1797e+00],\n",
       "          [-3.4375e-01,  4.6484e-01,  2.5391e-01,  ...,  3.2812e+00,\n",
       "            2.0312e+00,  1.7266e+00],\n",
       "          ...,\n",
       "          [-3.8281e-01, -2.4512e-01,  1.2158e-01,  ...,  2.1406e+00,\n",
       "            1.3047e+00,  3.2812e+00],\n",
       "          [-2.2168e-01, -3.3008e-01,  5.7812e-01,  ...,  1.1875e+00,\n",
       "            6.3281e-01,  2.8125e+00],\n",
       "          [-8.9844e-02, -1.7773e-01, -2.7734e-01,  ...,  2.8594e+00,\n",
       "           -3.7109e-01,  5.6250e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 4.9472e-06,  7.6599e-03, -1.0254e-02,  ...,  1.5198e-02,\n",
       "           -7.7515e-03, -2.0142e-02],\n",
       "          [-1.2793e-01,  2.0508e-01,  1.9336e-01,  ..., -5.1953e-01,\n",
       "           -3.3984e-01,  5.0391e-01],\n",
       "          [ 3.1006e-02, -3.0078e-01,  3.0078e-01,  ...,  5.6152e-02,\n",
       "           -3.5156e-01,  4.3359e-01],\n",
       "          ...,\n",
       "          [ 6.3281e-01, -7.6172e-02,  9.9219e-01,  ..., -4.2383e-01,\n",
       "            2.9297e-01, -6.0303e-02],\n",
       "          [ 2.4316e-01,  1.5137e-02,  5.3906e-01,  ..., -5.1172e-01,\n",
       "            2.1387e-01, -5.9814e-02],\n",
       "          [ 2.5195e-01, -3.6328e-01,  1.5723e-01,  ..., -7.8516e-01,\n",
       "            1.9727e-01, -2.4609e-01]],\n",
       "\n",
       "         [[-6.9275e-03, -2.9175e-02, -3.6133e-02,  ...,  5.0537e-02,\n",
       "           -5.7373e-03, -1.0193e-02],\n",
       "          [-4.0039e-01,  8.8379e-02, -5.3906e-01,  ..., -5.5469e-01,\n",
       "            5.1514e-02,  4.8047e-01],\n",
       "          [ 1.8066e-01,  5.3125e-01, -1.4746e-01,  ..., -9.8145e-02,\n",
       "            7.7344e-01,  6.6016e-01],\n",
       "          ...,\n",
       "          [ 8.6719e-01,  2.7710e-02,  7.3828e-01,  ...,  6.2256e-02,\n",
       "           -1.6113e-01,  2.7539e-01],\n",
       "          [ 1.6602e-02, -3.5858e-03,  2.9102e-01,  ..., -5.7617e-02,\n",
       "           -2.1289e-01,  2.8516e-01],\n",
       "          [-2.2461e-01, -4.2383e-01,  1.4609e+00,  ...,  4.7266e-01,\n",
       "            8.0859e-01,  3.4570e-01]],\n",
       "\n",
       "         [[-5.8105e-02, -5.0659e-03, -1.3733e-02,  ...,  3.2959e-02,\n",
       "            2.9541e-02,  1.0681e-02],\n",
       "          [-8.7402e-02,  3.9062e-01,  2.8711e-01,  ..., -4.8242e-01,\n",
       "            8.3984e-02, -6.4392e-03],\n",
       "          [ 4.4922e-01,  5.1172e-01,  1.7578e-01,  ..., -5.9766e-01,\n",
       "           -3.1250e-01, -1.1523e-01],\n",
       "          ...,\n",
       "          [-1.8848e-01, -5.7812e-01, -2.4512e-01,  ..., -2.5977e-01,\n",
       "           -2.9883e-01, -8.5938e-01],\n",
       "          [-2.1582e-01, -4.4141e-01,  8.5449e-03,  ..., -1.0645e-01,\n",
       "           -1.1816e-01, -9.3750e-02],\n",
       "          [ 5.8203e-01, -2.0020e-01,  1.1914e-01,  ..., -2.7100e-02,\n",
       "           -3.2422e-01, -8.2422e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 6.1951e-03,  7.2937e-03,  6.7383e-02,  ...,  1.6235e-02,\n",
       "           -1.7334e-02, -2.3071e-02],\n",
       "          [-1.0449e-01,  1.7578e-01, -5.6641e-01,  ..., -3.2031e-01,\n",
       "            2.2461e-01,  4.4141e-01],\n",
       "          [-6.7871e-02, -1.7334e-02, -2.4805e-01,  ..., -4.9805e-01,\n",
       "           -7.6172e-01,  5.3516e-01],\n",
       "          ...,\n",
       "          [ 3.0469e-01, -5.0000e-01,  4.0820e-01,  ..., -5.9375e-01,\n",
       "           -3.0078e-01,  8.6328e-01],\n",
       "          [ 1.4160e-01, -4.2969e-01,  1.1914e-01,  ..., -4.4141e-01,\n",
       "            9.2773e-02,  1.2598e-01],\n",
       "          [ 4.2578e-01, -1.1133e-01,  5.0000e-01,  ..., -6.1768e-02,\n",
       "           -6.5918e-02,  1.7969e-01]],\n",
       "\n",
       "         [[-1.6724e-02, -1.4648e-02, -7.6904e-03,  ..., -1.0071e-02,\n",
       "            5.9204e-03, -6.3477e-03],\n",
       "          [ 4.5312e-01,  1.5723e-01,  2.9688e-01,  ..., -1.0547e+00,\n",
       "           -2.5977e-01,  2.9297e-01],\n",
       "          [ 9.9219e-01,  1.0859e+00,  3.7354e-02,  ..., -9.1406e-01,\n",
       "           -2.8442e-02,  3.7500e-01],\n",
       "          ...,\n",
       "          [-8.0078e-02,  8.2422e-01, -2.9297e-01,  ..., -3.0078e-01,\n",
       "           -2.2266e-01, -6.5430e-02],\n",
       "          [ 2.3047e-01,  5.3125e-01,  3.4424e-02,  ..., -1.8457e-01,\n",
       "            3.7305e-01, -2.5781e-01],\n",
       "          [ 2.5586e-01,  8.6328e-01,  4.0771e-02,  ..., -1.9141e+00,\n",
       "            8.3984e-01, -4.6289e-01]],\n",
       "\n",
       "         [[ 1.1963e-02, -4.6997e-03,  4.4556e-03,  ...,  1.3550e-02,\n",
       "            3.0762e-02,  6.9275e-03],\n",
       "          [-4.3945e-01, -6.0156e-01, -4.3945e-01,  ..., -2.0996e-01,\n",
       "           -2.9102e-01, -6.8848e-02],\n",
       "          [-5.7422e-01, -7.4219e-01, -4.8633e-01,  ..., -3.7305e-01,\n",
       "           -3.2031e-01, -4.3359e-01],\n",
       "          ...,\n",
       "          [ 1.8262e-01, -2.3242e-01,  3.2812e-01,  ...,  5.8594e-01,\n",
       "           -8.3203e-01, -3.2812e-01],\n",
       "          [-2.4512e-01, -6.2109e-01,  4.4922e-01,  ...,  6.2500e-02,\n",
       "           -4.9219e-01, -7.9102e-02],\n",
       "          [-1.8262e-01, -4.1406e-01, -4.9805e-02,  ...,  1.8750e-01,\n",
       "           -1.2344e+00,  8.7402e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 3.5400e-02, -1.6357e-02,  1.9287e-02,  ..., -3.8672e-01,\n",
       "            2.6562e-01, -1.7480e-01],\n",
       "          [ 4.1504e-02,  2.0605e-01,  4.1797e-01,  ..., -1.3281e+00,\n",
       "           -5.4688e-01,  3.5742e-01],\n",
       "          [-3.3008e-01, -1.1406e+00,  2.3340e-01,  ..., -1.1172e+00,\n",
       "           -1.7344e+00,  1.9922e-01],\n",
       "          ...,\n",
       "          [-5.1172e-01,  6.6016e-01, -2.6562e-01,  ..., -2.5391e-01,\n",
       "           -1.4922e+00,  8.4766e-01],\n",
       "          [-2.5391e-01,  2.9883e-01,  1.7676e-01,  ...,  1.0625e+00,\n",
       "           -1.0391e+00,  2.4219e+00],\n",
       "          [ 2.8320e-02,  6.8359e-02,  1.3477e-01,  ...,  6.0547e-01,\n",
       "           -1.1016e+00, -5.6250e-01]],\n",
       "\n",
       "         [[-3.5400e-02,  4.6082e-03, -8.6060e-03,  ..., -5.8203e-01,\n",
       "           -7.4219e-01, -2.3730e-01],\n",
       "          [-1.9531e-03,  4.5166e-02,  2.3535e-01,  ..., -3.4375e+00,\n",
       "            9.6875e-01,  1.1172e+00],\n",
       "          [-7.8125e-01, -2.7344e-01,  7.2754e-02,  ..., -2.2031e+00,\n",
       "            1.8516e+00, -3.6328e-01],\n",
       "          ...,\n",
       "          [ 5.5859e-01, -1.6406e-01, -7.2266e-02,  ...,  2.9531e+00,\n",
       "            1.7109e+00,  8.1250e-01],\n",
       "          [ 4.1406e-01,  2.0117e-01,  2.0386e-02,  ...,  2.2188e+00,\n",
       "            3.7500e+00,  1.3125e+00],\n",
       "          [-4.1016e-02,  1.6479e-02, -4.9072e-02,  ...,  2.5781e+00,\n",
       "            3.1250e+00,  1.0781e+00]],\n",
       "\n",
       "         [[ 2.3193e-02, -2.2736e-03,  1.2146e-02,  ..., -1.2109e-01,\n",
       "            7.9102e-02,  3.5938e-01],\n",
       "          [-3.4375e-01,  3.5547e-01,  1.9336e-01,  ..., -1.5000e+00,\n",
       "            2.4219e+00, -1.3516e+00],\n",
       "          [-1.1914e-01,  6.8750e-01,  2.2461e-01,  ..., -1.9375e+00,\n",
       "            3.5469e+00, -1.2988e-01],\n",
       "          ...,\n",
       "          [ 4.9561e-02,  1.2354e-01, -2.3340e-01,  ...,  3.8750e+00,\n",
       "            2.4219e+00,  2.1406e+00],\n",
       "          [-3.3789e-01, -8.4375e-01, -2.8906e-01,  ...,  2.1562e+00,\n",
       "            2.5000e+00,  5.5469e-01],\n",
       "          [ 6.6797e-01,  7.0801e-02,  1.0391e+00,  ...,  3.3750e+00,\n",
       "            3.5625e+00, -1.1328e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 6.3477e-02,  1.1978e-03, -2.0386e-02,  ..., -4.4922e-01,\n",
       "           -7.9102e-02,  4.0430e-01],\n",
       "          [-2.2461e-01,  6.2500e-02,  5.3223e-02,  ...,  3.8086e-01,\n",
       "           -2.9531e+00, -4.2812e+00],\n",
       "          [-6.6895e-02, -9.0625e-01,  6.8359e-01,  ...,  1.3906e+00,\n",
       "           -3.1094e+00, -3.0156e+00],\n",
       "          ...,\n",
       "          [ 6.5625e-01, -5.6396e-02, -5.7422e-01,  ...,  8.5938e-01,\n",
       "            4.3750e-01, -1.6406e+00],\n",
       "          [ 8.7500e-01,  5.0781e-01, -2.1777e-01,  ...,  1.1875e+00,\n",
       "            8.9062e-01, -2.0625e+00],\n",
       "          [ 6.3477e-03, -1.3965e-01,  5.1270e-02,  ..., -1.2188e+00,\n",
       "           -7.2656e-01,  1.6113e-01]],\n",
       "\n",
       "         [[-1.8921e-03,  1.4893e-02,  4.9561e-02,  ...,  7.2656e-01,\n",
       "            3.9062e-01, -2.3594e+00],\n",
       "          [ 2.8711e-01,  9.8145e-02, -9.7656e-02,  ...,  8.9844e-01,\n",
       "           -3.9375e+00,  4.5938e+00],\n",
       "          [ 8.0078e-02, -1.4062e-01,  1.5332e-01,  ...,  9.0234e-01,\n",
       "           -3.6875e+00,  6.0938e+00],\n",
       "          ...,\n",
       "          [ 4.0039e-01,  4.8828e-03, -1.0742e-01,  ...,  1.0625e+00,\n",
       "            2.9492e-01,  4.5938e+00],\n",
       "          [ 6.1035e-02,  5.6641e-02,  3.3203e-02,  ..., -7.3047e-01,\n",
       "            9.9609e-01,  4.5312e+00],\n",
       "          [-9.1797e-02, -3.5645e-02, -5.3906e-01,  ..., -1.4062e+00,\n",
       "           -2.1387e-01,  5.1562e+00]],\n",
       "\n",
       "         [[-4.4861e-03,  1.3123e-02,  1.3062e-02,  ...,  2.9102e-01,\n",
       "           -1.7969e-01, -1.2500e-01],\n",
       "          [ 4.9805e-01, -1.0254e-01,  1.7212e-02,  ...,  2.3750e+00,\n",
       "            5.3516e-01, -1.3438e+00],\n",
       "          [-4.2188e-01, -5.2734e-01,  2.6172e-01,  ...,  1.6250e+00,\n",
       "            1.0938e+00, -1.4453e+00],\n",
       "          ...,\n",
       "          [ 3.5156e-02,  2.3535e-01, -7.3242e-03,  ..., -1.6641e+00,\n",
       "            1.8203e+00,  2.8711e-01],\n",
       "          [ 3.8477e-01, -1.2891e-01, -3.9258e-01,  ..., -9.4922e-01,\n",
       "            2.3594e+00,  2.5938e+00],\n",
       "          [ 1.9775e-02,  7.9590e-02, -6.9336e-02,  ...,  8.7500e-01,\n",
       "            1.5234e+00,  2.2344e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-1.0071e-02, -6.5308e-03, -1.2146e-02,  ...,  1.1108e-02,\n",
       "            1.5381e-02, -4.9438e-03],\n",
       "          [ 6.4062e-01, -1.0859e+00, -8.2422e-01,  ...,  1.2578e+00,\n",
       "            3.7305e-01, -6.9922e-01],\n",
       "          [ 4.8633e-01, -5.8594e-01, -2.3438e-01,  ...,  8.8672e-01,\n",
       "            4.6289e-01, -2.0020e-01],\n",
       "          ...,\n",
       "          [ 5.3906e-01,  2.4872e-03, -4.1797e-01,  ...,  2.8906e-01,\n",
       "           -1.7480e-01, -9.2188e-01],\n",
       "          [ 1.1182e-01, -1.4453e-01, -2.4805e-01,  ..., -5.0781e-01,\n",
       "            9.2773e-02, -3.0078e-01],\n",
       "          [ 7.8906e-01,  7.3047e-01,  3.2031e-01,  ..., -7.0703e-01,\n",
       "            2.2266e-01, -1.6797e-01]],\n",
       "\n",
       "         [[-3.7109e-02, -1.7822e-02,  3.5553e-03,  ..., -1.0071e-02,\n",
       "            4.2725e-03,  1.2634e-02],\n",
       "          [-4.8242e-01,  2.9907e-02,  5.5078e-01,  ..., -9.8145e-02,\n",
       "           -4.6094e-01,  4.2773e-01],\n",
       "          [ 2.9688e-01, -6.6797e-01, -1.8848e-01,  ...,  6.1719e-01,\n",
       "            7.3047e-01,  6.8750e-01],\n",
       "          ...,\n",
       "          [-5.6250e-01,  3.4424e-02, -1.2598e-01,  ...,  7.9688e-01,\n",
       "           -6.0938e-01, -5.1953e-01],\n",
       "          [ 7.5781e-01, -5.9766e-01,  4.4531e-01,  ...,  5.5078e-01,\n",
       "            1.0078e+00,  6.2109e-01],\n",
       "          [ 7.0703e-01,  4.1748e-02,  9.2969e-01,  ...,  4.6484e-01,\n",
       "           -1.0234e+00,  6.9141e-01]],\n",
       "\n",
       "         [[ 1.1063e-03, -8.3542e-04, -1.8311e-02,  ..., -1.4404e-02,\n",
       "           -1.9653e-02, -2.1362e-03],\n",
       "          [ 1.6797e-01,  3.9648e-01,  6.8359e-02,  ...,  1.1816e-01,\n",
       "            5.4199e-02,  4.9316e-02],\n",
       "          [-2.1191e-01,  3.6133e-01,  6.1768e-02,  ...,  4.6094e-01,\n",
       "            1.6504e-01, -1.9922e-01],\n",
       "          ...,\n",
       "          [ 1.3086e-01, -3.9453e-01, -4.4727e-01,  ...,  4.1211e-01,\n",
       "           -4.3945e-01,  6.2988e-02],\n",
       "          [ 2.9688e-01, -1.7480e-01, -2.0605e-01,  ...,  2.4609e-01,\n",
       "           -2.8711e-01, -1.5430e-01],\n",
       "          [-2.5781e-01,  2.3242e-01, -5.8594e-01,  ..., -5.8350e-02,\n",
       "           -6.5625e-01,  3.4570e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.1902e-02,  1.7212e-02, -1.2207e-02,  ..., -2.6978e-02,\n",
       "           -1.4267e-03,  1.0010e-02],\n",
       "          [ 3.9844e-01, -1.0938e-01,  1.9165e-02,  ..., -2.7344e-01,\n",
       "           -9.8267e-03, -4.0820e-01],\n",
       "          [-2.5977e-01,  1.8164e-01, -1.0559e-02,  ..., -4.1992e-01,\n",
       "            3.3789e-01, -9.0234e-01],\n",
       "          ...,\n",
       "          [-2.2461e-02,  5.2734e-01, -3.3008e-01,  ...,  6.5234e-01,\n",
       "           -1.6016e-01, -4.7656e-01],\n",
       "          [-5.0049e-02,  4.6484e-01, -7.7344e-01,  ..., -1.5625e-01,\n",
       "            6.4941e-02, -1.1914e-01],\n",
       "          [ 1.1328e+00,  4.7852e-01, -1.0391e+00,  ...,  1.0469e+00,\n",
       "            1.4551e-01,  5.8984e-01]],\n",
       "\n",
       "         [[-1.9226e-03, -1.4465e-02, -9.0332e-03,  ...,  6.5308e-03,\n",
       "            2.4109e-03,  7.2632e-03],\n",
       "          [ 3.0078e-01,  1.4941e-01,  7.3047e-01,  ...,  3.9062e-01,\n",
       "            3.0664e-01, -3.0859e-01],\n",
       "          [ 1.0449e-01,  8.1641e-01,  2.7930e-01,  ..., -1.7773e-01,\n",
       "            2.0312e-01, -3.4375e-01],\n",
       "          ...,\n",
       "          [-2.7734e-01, -3.7109e-02,  3.2812e-01,  ...,  9.1016e-01,\n",
       "            1.4941e-01,  4.4922e-01],\n",
       "          [ 9.5312e-01,  1.1719e-01,  8.2031e-01,  ...,  7.0312e-01,\n",
       "            1.1621e-01,  2.3828e-01],\n",
       "          [ 5.4297e-01, -8.3496e-02,  1.2266e+00,  ...,  1.5234e+00,\n",
       "           -8.5449e-03, -3.6523e-01]],\n",
       "\n",
       "         [[-1.2878e-02,  1.1353e-02, -1.1377e-01,  ...,  5.4169e-04,\n",
       "           -3.4668e-02, -4.2114e-03],\n",
       "          [ 4.2578e-01, -5.6250e-01,  4.3164e-01,  ...,  4.2236e-02,\n",
       "            3.4766e-01,  7.5000e-01],\n",
       "          [ 6.3672e-01, -6.5625e-01,  2.8711e-01,  ...,  4.8828e-01,\n",
       "            1.3379e-01,  6.0938e-01],\n",
       "          ...,\n",
       "          [-9.5703e-01, -5.5469e-01,  3.8281e-01,  ...,  4.0625e-01,\n",
       "           -2.5391e-01, -5.7812e-01],\n",
       "          [-2.4902e-01, -2.8711e-01,  3.6914e-01,  ...,  8.8379e-02,\n",
       "            8.3496e-02, -1.0559e-02],\n",
       "          [ 9.0234e-01, -1.9043e-01, -3.4766e-01,  ...,  4.7119e-02,\n",
       "            1.1875e+00, -1.1250e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-2.2217e-02,  1.0559e-02, -1.5015e-02,  ...,  5.9375e-01,\n",
       "           -3.5156e-01, -6.8359e-01],\n",
       "          [-2.1582e-01, -7.8125e-02,  4.2969e-02,  ..., -4.0625e-01,\n",
       "           -1.3516e+00, -2.3145e-01],\n",
       "          [-2.3535e-01, -5.0391e-01, -4.9609e-01,  ..., -1.6875e+00,\n",
       "           -2.0469e+00, -3.2715e-02],\n",
       "          ...,\n",
       "          [ 2.0117e-01, -2.4902e-01, -6.5625e-01,  ..., -2.5000e+00,\n",
       "            4.3555e-01, -1.1953e+00],\n",
       "          [ 2.3438e-01,  1.0254e-01, -2.7344e-01,  ..., -3.0469e+00,\n",
       "            8.8281e-01,  7.0703e-01],\n",
       "          [ 2.0020e-01, -1.3574e-01, -7.5195e-02,  ..., -3.2969e+00,\n",
       "           -5.3516e-01, -1.9922e+00]],\n",
       "\n",
       "         [[-9.3994e-03, -8.1177e-03, -1.4526e-02,  ...,  6.0303e-02,\n",
       "           -1.2207e-01, -8.3008e-02],\n",
       "          [ 2.8516e-01,  8.6719e-01,  1.3672e-02,  ..., -6.8359e-02,\n",
       "           -1.1562e+00,  2.7734e-01],\n",
       "          [ 2.6367e-01,  1.9531e-01, -8.8867e-02,  ..., -4.9805e-02,\n",
       "           -1.6641e+00,  6.5625e-01],\n",
       "          ...,\n",
       "          [-8.7500e-01,  5.0391e-01,  2.4219e-01,  ..., -1.1250e+00,\n",
       "           -2.7500e+00,  4.7852e-01],\n",
       "          [ 4.6484e-01,  6.5625e-01, -5.0781e-01,  ...,  2.8711e-01,\n",
       "           -2.1250e+00,  8.9453e-01],\n",
       "          [ 4.0625e-01,  3.3203e-01, -1.2988e-01,  ..., -6.8359e-01,\n",
       "           -5.4297e-01,  1.1875e+00]],\n",
       "\n",
       "         [[ 1.3733e-02,  3.8574e-02, -7.5073e-03,  ..., -5.6250e-01,\n",
       "            1.1963e-01,  2.5977e-01],\n",
       "          [ 1.0156e-01, -3.9453e-01,  1.7090e-02,  ...,  1.9297e+00,\n",
       "            4.7266e-01, -2.1094e+00],\n",
       "          [ 4.7656e-01, -6.5625e-01,  6.2891e-01,  ...,  2.3594e+00,\n",
       "           -1.3438e+00, -1.7188e+00],\n",
       "          ...,\n",
       "          [-3.7109e-01,  3.3789e-01,  3.3008e-01,  ...,  2.0625e+00,\n",
       "           -6.9141e-01,  2.3926e-01],\n",
       "          [-7.0703e-01,  8.0469e-01,  6.8359e-03,  ...,  1.8984e+00,\n",
       "           -1.7891e+00, -2.6875e+00],\n",
       "          [ 4.7266e-01,  1.6309e-01, -4.3945e-01,  ...,  8.2031e-01,\n",
       "           -6.2500e-01, -3.4531e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.4191e-03,  9.4604e-03,  1.9989e-03,  ...,  7.3730e-02,\n",
       "           -3.6621e-02, -6.6406e-02],\n",
       "          [ 2.9492e-01,  4.8828e-02,  5.0781e-01,  ..., -2.4805e-01,\n",
       "            2.9297e-01,  5.1953e-01],\n",
       "          [ 5.1562e-01,  3.2812e-01,  2.8564e-02,  ...,  3.2422e-01,\n",
       "            1.0234e+00,  4.1992e-01],\n",
       "          ...,\n",
       "          [-9.0625e-01,  2.3281e+00, -5.9570e-02,  ..., -3.3447e-02,\n",
       "           -8.5156e-01, -2.8516e-01],\n",
       "          [ 1.7188e+00,  3.9648e-01, -1.3047e+00,  ...,  6.4062e-01,\n",
       "           -1.3965e-01, -1.0234e+00],\n",
       "          [-2.8320e-01,  2.3242e-01,  1.0059e-01,  ..., -1.2734e+00,\n",
       "           -1.3828e+00,  1.2969e+00]],\n",
       "\n",
       "         [[ 4.1260e-02,  1.1230e-02, -7.8735e-03,  ...,  1.9824e-01,\n",
       "           -1.1279e-01,  5.9326e-02],\n",
       "          [-3.5889e-02,  1.6016e-01, -4.5312e-01,  ..., -1.0469e+00,\n",
       "            8.5156e-01, -1.8555e-01],\n",
       "          [-2.5513e-02,  4.1211e-01, -4.2773e-01,  ..., -2.0410e-01,\n",
       "            7.3047e-01, -5.6250e-01],\n",
       "          ...,\n",
       "          [-1.5332e-01, -2.4316e-01, -2.6758e-01,  ..., -9.7266e-01,\n",
       "           -1.2031e+00,  8.3984e-01],\n",
       "          [ 4.5703e-01, -6.6406e-01, -6.9824e-02,  ..., -2.1094e+00,\n",
       "           -8.9844e-01,  1.5391e+00],\n",
       "          [ 3.3984e-01,  8.0566e-02,  6.1719e-01,  ...,  2.5781e+00,\n",
       "           -2.5156e+00,  1.7734e+00]],\n",
       "\n",
       "         [[ 1.2451e-02,  4.6082e-03, -2.5513e-02,  ...,  2.2461e-01,\n",
       "           -1.2891e+00,  2.4414e-01],\n",
       "          [-5.7812e-01,  3.2031e-01, -3.0273e-01,  ..., -1.1328e+00,\n",
       "           -1.5000e+00, -2.5938e+00],\n",
       "          [-2.0605e-01, -1.5991e-02,  1.3574e-01,  ..., -2.2188e+00,\n",
       "           -1.1250e+00, -1.7109e+00],\n",
       "          ...,\n",
       "          [-1.6602e-01, -5.4688e-02,  4.7852e-01,  ..., -1.5938e+00,\n",
       "            1.2656e+00, -3.0156e+00],\n",
       "          [ 3.1445e-01,  2.1484e-01, -1.9531e-01,  ..., -1.1172e+00,\n",
       "            1.9688e+00, -3.3438e+00],\n",
       "          [ 3.9844e-01,  1.4160e-01, -3.6719e-01,  ...,  5.8594e-02,\n",
       "            1.3281e+00, -6.0156e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-4.5471e-03,  3.0859e-01,  5.1880e-03,  ...,  3.4180e-02,\n",
       "            2.2827e-02, -1.1047e-02],\n",
       "          [-1.9062e+00, -1.4922e+00,  2.5586e-01,  ...,  2.8906e-01,\n",
       "           -7.9102e-02,  1.5918e-01],\n",
       "          [-1.0312e+00, -1.1094e+00,  5.3516e-01,  ...,  1.2891e-01,\n",
       "           -2.4609e-01,  1.2109e-01],\n",
       "          ...,\n",
       "          [-3.4180e-01, -1.5703e+00,  6.0156e-01,  ...,  3.1250e-01,\n",
       "            3.9453e-01, -1.7456e-02],\n",
       "          [-1.0156e+00, -1.0234e+00, -5.5908e-02,  ...,  1.2598e-01,\n",
       "            1.1953e+00, -5.5469e-01],\n",
       "          [-1.2422e+00, -1.4922e+00,  2.0312e-01,  ..., -2.1240e-02,\n",
       "            9.1406e-01, -7.0703e-01]],\n",
       "\n",
       "         [[ 1.2329e-02,  2.0386e-02, -2.6367e-02,  ..., -1.9836e-03,\n",
       "           -5.3406e-03, -2.9297e-01],\n",
       "          [-3.9258e-01,  3.9062e-02,  1.0986e-01,  ..., -1.4551e-01,\n",
       "            1.3733e-02,  1.3594e+00],\n",
       "          [-2.4512e-01,  2.5391e-01, -8.1055e-02,  ..., -2.4805e-01,\n",
       "           -6.1328e-01,  1.5000e+00],\n",
       "          ...,\n",
       "          [ 9.6484e-01, -9.1309e-02, -6.3281e-01,  ..., -1.5703e+00,\n",
       "            2.1387e-01,  1.6641e+00],\n",
       "          [ 8.3594e-01, -2.4316e-01, -1.8555e-02,  ..., -5.8203e-01,\n",
       "            1.9336e-01,  1.3984e+00],\n",
       "          [ 1.0234e+00, -1.0859e+00, -8.7109e-01,  ..., -6.5625e-01,\n",
       "            3.0859e-01,  1.1641e+00]],\n",
       "\n",
       "         [[-5.3101e-03,  2.8076e-03, -8.6670e-03,  ...,  9.7656e-03,\n",
       "           -3.0823e-03, -1.7822e-02],\n",
       "          [ 2.5391e-01, -5.5469e-01,  3.4766e-01,  ..., -7.0312e-02,\n",
       "            8.6719e-01, -3.1055e-01],\n",
       "          [ 3.9062e-01, -6.1035e-02,  2.1191e-01,  ...,  3.6133e-02,\n",
       "            1.4453e-01,  8.8501e-03],\n",
       "          ...,\n",
       "          [ 5.5469e-01, -5.7422e-01, -1.5781e+00,  ...,  1.1182e-01,\n",
       "            3.6719e-01,  5.0000e-01],\n",
       "          [ 7.8125e-01,  4.5312e-01, -7.6172e-01,  ...,  1.7188e-01,\n",
       "            4.4141e-01,  2.0996e-01],\n",
       "          [-4.3457e-02,  1.3489e-02, -5.3906e-01,  ..., -3.1641e-01,\n",
       "            2.9907e-02,  2.1875e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-2.1973e-02,  6.0425e-03,  5.1575e-03,  ..., -3.3203e-02,\n",
       "           -9.4604e-03, -1.3611e-02],\n",
       "          [-9.9609e-02, -8.0469e-01,  2.3926e-01,  ...,  1.1841e-02,\n",
       "            1.6113e-01, -5.1562e-01],\n",
       "          [ 2.0898e-01,  7.2327e-03,  2.8125e-01,  ..., -3.0469e-01,\n",
       "           -7.1289e-02, -3.0664e-01],\n",
       "          ...,\n",
       "          [ 7.6172e-02, -8.3594e-01,  4.1260e-02,  ...,  6.7188e-01,\n",
       "            5.1172e-01,  3.3789e-01],\n",
       "          [ 8.0469e-01, -6.3281e-01, -4.7070e-01,  ...,  6.2109e-01,\n",
       "           -6.9922e-01,  5.7031e-01],\n",
       "          [ 7.8906e-01, -3.9258e-01,  2.2363e-01,  ...,  1.5234e+00,\n",
       "            6.3672e-01, -4.5117e-01]],\n",
       "\n",
       "         [[-1.7700e-02, -1.0742e-02, -1.1475e-02,  ..., -8.9111e-03,\n",
       "           -3.7695e-01, -6.5994e-04],\n",
       "          [ 8.1543e-02,  7.3730e-02,  2.9297e-01,  ..., -2.2363e-01,\n",
       "            1.7188e-01, -3.2812e-01],\n",
       "          [-1.4844e-01,  4.9072e-02,  3.1836e-01,  ..., -1.9922e-01,\n",
       "            2.8906e-01, -3.1055e-01],\n",
       "          ...,\n",
       "          [ 3.6133e-01, -7.4219e-01,  4.0039e-01,  ..., -5.1172e-01,\n",
       "            1.2031e+00,  5.9375e-01],\n",
       "          [-6.3965e-02, -4.4531e-01,  5.1562e-01,  ..., -5.4688e-01,\n",
       "            9.4141e-01,  6.4453e-01],\n",
       "          [-3.4180e-01, -4.4727e-01,  6.9922e-01,  ...,  1.3379e-01,\n",
       "            8.5547e-01,  3.7500e-01]],\n",
       "\n",
       "         [[-1.8845e-03, -1.0620e-02,  2.2461e-02,  ..., -1.0071e-02,\n",
       "           -5.8594e-03,  1.0071e-02],\n",
       "          [ 3.5156e-01,  7.6660e-02, -6.2891e-01,  ..., -5.0391e-01,\n",
       "           -2.1289e-01, -9.1406e-01],\n",
       "          [ 4.2188e-01,  5.1953e-01, -2.4609e-01,  ...,  8.8867e-02,\n",
       "            6.0059e-02, -4.0625e-01],\n",
       "          ...,\n",
       "          [-4.6387e-02,  8.3496e-02,  2.3535e-01,  ...,  6.7578e-01,\n",
       "            2.7539e-01, -2.9883e-01],\n",
       "          [ 5.7422e-01, -8.3496e-02, -2.3535e-01,  ..., -4.9316e-02,\n",
       "           -4.1992e-01,  3.0664e-01],\n",
       "          [ 3.0151e-02,  8.0078e-01, -5.0391e-01,  ...,  3.7305e-01,\n",
       "           -5.3906e-01, -1.5137e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-3.2715e-02,  3.0273e-02, -3.0762e-02,  ...,  3.5742e-01,\n",
       "            7.4219e-02,  1.4844e-01],\n",
       "          [-3.6914e-01, -3.5938e-01,  5.8594e-01,  ..., -3.7500e-01,\n",
       "            2.4688e+00,  1.0391e+00],\n",
       "          [-2.7148e-01,  6.7578e-01,  3.0859e-01,  ..., -1.0703e+00,\n",
       "            1.0938e+00, -1.7969e-01],\n",
       "          ...,\n",
       "          [-3.7109e-01,  1.0547e-01, -4.9023e-01,  ..., -8.8281e-01,\n",
       "           -2.1582e-01, -2.1250e+00],\n",
       "          [ 2.1606e-02,  3.9648e-01,  5.2734e-01,  ..., -9.1016e-01,\n",
       "           -1.0498e-01, -5.7031e-01],\n",
       "          [ 1.9434e-01, -5.3125e-01, -6.1328e-01,  ...,  1.8281e+00,\n",
       "            1.0312e+00, -3.6719e-01]],\n",
       "\n",
       "         [[-2.3438e-02,  5.9814e-03, -1.3855e-02,  ...,  3.5938e-01,\n",
       "            4.1992e-01,  6.2891e-01],\n",
       "          [-1.3867e-01,  4.1406e-01,  3.7695e-01,  ..., -9.5312e-01,\n",
       "            8.3594e-01, -2.1875e+00],\n",
       "          [-4.6680e-01,  4.2188e-01,  1.2793e-01,  ..., -1.2578e+00,\n",
       "            1.1484e+00, -2.4844e+00],\n",
       "          ...,\n",
       "          [-2.2461e-01,  1.5039e-01, -9.2969e-01,  ..., -2.0469e+00,\n",
       "           -1.0859e+00, -7.8125e-01],\n",
       "          [-1.0791e-01,  1.1523e-01, -5.1172e-01,  ..., -1.0078e+00,\n",
       "           -6.9922e-01,  4.3750e-01],\n",
       "          [-3.9258e-01,  2.7344e-01, -5.8203e-01,  ..., -2.0625e+00,\n",
       "           -6.0547e-01,  1.3359e+00]],\n",
       "\n",
       "         [[-3.1281e-04,  3.0151e-02,  1.2573e-02,  ..., -1.6724e-02,\n",
       "            1.7969e-01,  3.0273e-01],\n",
       "          [-1.2695e-01,  2.9688e-01,  4.0039e-01,  ...,  1.0938e+00,\n",
       "           -5.0391e-01, -3.3594e-01],\n",
       "          [-3.6328e-01,  3.0273e-01,  5.4297e-01,  ...,  1.3125e+00,\n",
       "           -3.0664e-01,  1.2656e+00],\n",
       "          ...,\n",
       "          [ 3.5938e-01,  2.6367e-01, -7.7344e-01,  ..., -3.7891e-01,\n",
       "            1.2793e-01,  1.5000e+00],\n",
       "          [ 7.5781e-01,  2.4707e-01, -2.2949e-02,  ...,  8.8672e-01,\n",
       "            6.1719e-01,  7.0703e-01],\n",
       "          [ 7.4707e-02,  6.6797e-01,  6.2500e-02,  ...,  8.0859e-01,\n",
       "           -1.3281e+00,  1.3672e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.5757e-02,  6.9580e-03, -1.1230e-02,  ..., -3.8574e-02,\n",
       "           -2.2559e-01, -4.9609e-01],\n",
       "          [-7.0312e-02, -6.0547e-01, -4.6289e-01,  ..., -1.0498e-01,\n",
       "            1.1562e+00, -3.5742e-01],\n",
       "          [-1.0625e+00, -1.1641e+00,  2.4170e-02,  ..., -4.4531e-01,\n",
       "            1.7031e+00, -3.5742e-01],\n",
       "          ...,\n",
       "          [-2.2852e-01, -4.7266e-01, -2.7734e-01,  ..., -1.0547e+00,\n",
       "            3.9844e-01,  4.4727e-01],\n",
       "          [-7.5000e-01,  4.8438e-01,  2.5586e-01,  ..., -3.8281e-01,\n",
       "            6.5918e-02,  1.1484e+00],\n",
       "          [ 1.1133e-01,  1.3672e-02, -3.2812e-01,  ..., -4.4336e-01,\n",
       "           -6.7188e-01,  1.6113e-01]],\n",
       "\n",
       "         [[ 2.0996e-02, -2.8076e-02, -3.2227e-02,  ...,  7.4219e-02,\n",
       "           -1.6797e+00,  1.4648e-02],\n",
       "          [-1.8457e-01, -2.7539e-01,  9.6191e-02,  ..., -1.4062e+00,\n",
       "            2.9688e+00,  1.3203e+00],\n",
       "          [ 6.7383e-02, -2.2217e-02,  3.6719e-01,  ..., -2.0469e+00,\n",
       "            4.0938e+00,  7.6953e-01],\n",
       "          ...,\n",
       "          [-1.8164e-01,  1.0498e-01, -4.4531e-01,  ...,  5.2344e-01,\n",
       "            4.5938e+00,  1.0234e+00],\n",
       "          [-1.2793e-01, -1.0449e-01, -1.6211e-01,  ..., -2.0156e+00,\n",
       "            4.5938e+00,  2.8125e+00],\n",
       "          [-3.0762e-02, -2.7344e-01, -3.3203e-01,  ..., -1.9531e+00,\n",
       "            4.6250e+00,  2.1719e+00]],\n",
       "\n",
       "         [[ 5.3406e-03,  2.1973e-02,  1.1230e-02,  ...,  2.0801e-01,\n",
       "            5.8203e-01,  1.7090e-02],\n",
       "          [ 9.5215e-02,  2.3828e-01,  2.0898e-01,  ..., -1.6328e+00,\n",
       "           -6.7188e-01, -1.3281e+00],\n",
       "          [ 1.1768e-01, -2.0215e-01,  5.4443e-02,  ..., -2.2812e+00,\n",
       "           -3.6914e-01, -1.3203e+00],\n",
       "          ...,\n",
       "          [-1.8066e-01,  1.1670e-01, -1.4844e-01,  ..., -1.1875e+00,\n",
       "           -2.0000e+00, -1.8516e+00],\n",
       "          [-4.2188e-01,  1.9531e-03,  6.6016e-01,  ..., -2.0264e-02,\n",
       "           -1.9844e+00, -2.6250e+00],\n",
       "          [-1.3477e-01,  1.6016e-01,  6.6406e-01,  ..., -6.4453e-02,\n",
       "           -2.2500e+00, -3.0000e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 1.8555e-02, -1.5625e-02,  9.8267e-03,  ..., -5.4016e-03,\n",
       "           -6.5308e-03, -9.3994e-03],\n",
       "          [-3.7500e-01,  9.8047e-01,  6.0547e-01,  ...,  5.6641e-01,\n",
       "            1.1719e+00,  6.4844e-01],\n",
       "          [-4.2578e-01,  5.8984e-01,  3.0078e-01,  ...,  1.0938e+00,\n",
       "            7.0312e-01,  5.8984e-01],\n",
       "          ...,\n",
       "          [ 5.7812e-01,  1.7285e-01, -2.7734e-01,  ...,  3.2227e-01,\n",
       "           -7.6660e-02,  4.4141e-01],\n",
       "          [ 6.7188e-01,  4.5117e-01, -1.7773e-01,  ..., -5.2490e-02,\n",
       "           -2.6758e-01,  3.0859e-01],\n",
       "          [-5.7031e-01, -5.8594e-01, -1.2695e-01,  ...,  2.3535e-01,\n",
       "           -3.1055e-01, -1.0059e-01]],\n",
       "\n",
       "         [[ 1.3916e-02, -7.7515e-03,  8.1787e-03,  ...,  3.6926e-03,\n",
       "           -7.8735e-03, -1.0742e-02],\n",
       "          [ 7.9297e-01, -2.7539e-01,  6.5430e-02,  ...,  6.5918e-02,\n",
       "            3.7842e-02, -3.3984e-01],\n",
       "          [ 9.4141e-01,  2.5269e-02,  9.0332e-02,  ...,  3.5400e-02,\n",
       "            3.6133e-02,  2.3438e-01],\n",
       "          ...,\n",
       "          [ 5.8984e-01, -8.4839e-03, -1.0303e-01,  ...,  4.8828e-01,\n",
       "           -5.6641e-01,  4.0039e-01],\n",
       "          [ 6.5625e-01,  5.7812e-01, -3.9844e-01,  ..., -2.0142e-02,\n",
       "            4.2188e-01,  6.0059e-02],\n",
       "          [ 1.3867e-01,  2.1094e-01, -5.7031e-01,  ..., -8.1250e-01,\n",
       "            9.7656e-02,  1.5625e-01]],\n",
       "\n",
       "         [[-6.4087e-04,  1.2329e-02,  1.5625e-02,  ...,  1.5869e-02,\n",
       "           -1.8066e-02, -4.3457e-02],\n",
       "          [-2.8125e-01, -2.5195e-01, -8.3984e-02,  ...,  3.3203e-01,\n",
       "           -2.5195e-01,  1.3672e+00],\n",
       "          [ 6.0156e-01, -1.7480e-01,  1.5332e-01,  ..., -1.0303e-01,\n",
       "            5.7812e-01,  7.8906e-01],\n",
       "          ...,\n",
       "          [ 9.8828e-01, -1.0303e-01, -3.6328e-01,  ..., -2.1484e-01,\n",
       "            2.5586e-01,  3.5547e-01],\n",
       "          [ 8.3594e-01, -2.7344e-01,  1.9688e+00,  ..., -3.2422e-01,\n",
       "           -8.0469e-01,  5.8203e-01],\n",
       "          [ 4.2773e-01, -1.9336e-01,  4.1797e-01,  ...,  2.9883e-01,\n",
       "           -1.1133e-01, -3.5352e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-4.6387e-03,  9.0332e-03,  3.5248e-03,  ..., -1.3489e-02,\n",
       "           -1.1780e-02, -1.0925e-02],\n",
       "          [ 1.7578e-01,  9.3359e-01,  7.1094e-01,  ..., -5.5078e-01,\n",
       "            4.7656e-01, -6.3281e-01],\n",
       "          [-4.7852e-01,  6.4453e-01,  9.0820e-02,  ..., -4.7461e-01,\n",
       "            5.5859e-01,  1.0449e-01],\n",
       "          ...,\n",
       "          [ 5.2344e-01,  6.8750e-01,  3.8574e-02,  ..., -1.6895e-01,\n",
       "            6.3672e-01, -3.4375e-01],\n",
       "          [ 5.5908e-02,  9.0234e-01, -2.9297e-01,  ...,  6.9531e-01,\n",
       "            6.2109e-01, -2.4316e-01],\n",
       "          [ 2.2461e-01,  1.6504e-01, -8.6719e-01,  ..., -1.7578e-01,\n",
       "            4.7266e-01, -2.5757e-02]],\n",
       "\n",
       "         [[-3.0212e-03,  8.3618e-03, -3.7109e-02,  ...,  1.4526e-02,\n",
       "            8.6060e-03, -1.2329e-02],\n",
       "          [-2.2070e-01, -4.1406e-01, -1.0000e+00,  ..., -9.7266e-01,\n",
       "            3.5156e-01, -2.6367e-01],\n",
       "          [-1.6699e-01, -1.0547e-01, -1.6895e-01,  ..., -5.1562e-01,\n",
       "            7.1289e-02,  3.2715e-02],\n",
       "          ...,\n",
       "          [-4.0283e-02, -4.1602e-01, -4.4922e-01,  ...,  3.1055e-01,\n",
       "            8.1055e-02, -5.1953e-01],\n",
       "          [ 7.7734e-01, -7.9688e-01,  3.8867e-01,  ..., -3.0273e-01,\n",
       "           -3.0273e-01, -7.7344e-01],\n",
       "          [ 2.8906e-01,  2.3242e-01,  4.6094e-01,  ...,  1.7383e-01,\n",
       "            5.8203e-01, -2.4023e-01]],\n",
       "\n",
       "         [[ 5.6763e-03,  2.5269e-02, -8.2397e-03,  ...,  6.5231e-04,\n",
       "            1.6968e-02,  6.0425e-03],\n",
       "          [-3.3008e-01,  2.2559e-01, -1.9165e-02,  ..., -9.5703e-02,\n",
       "           -4.8047e-01, -4.4727e-01],\n",
       "          [-8.5449e-02,  6.2500e-01,  3.1836e-01,  ...,  1.3379e-01,\n",
       "           -5.1172e-01, -3.3594e-01],\n",
       "          ...,\n",
       "          [ 2.9688e-01, -1.7773e-01,  5.7031e-01,  ...,  8.0859e-01,\n",
       "            2.9297e-01, -3.0469e-01],\n",
       "          [ 1.3477e-01,  1.0107e-01, -4.5654e-02,  ...,  3.5547e-01,\n",
       "            1.2451e-01,  4.0430e-01],\n",
       "          [-1.0156e+00,  1.1816e-01, -3.2812e-01,  ...,  6.1328e-01,\n",
       "           -4.9219e-01, -8.3203e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-3.3447e-02, -6.4392e-03,  1.6235e-02,  ..., -2.3242e-01,\n",
       "            5.1562e-01,  7.4219e-01],\n",
       "          [-3.1055e-01, -1.0254e-01,  2.7930e-01,  ...,  1.1572e-01,\n",
       "           -1.1016e+00, -1.0938e+00],\n",
       "          [-7.7734e-01, -3.0469e-01, -3.1641e-01,  ...,  4.8438e-01,\n",
       "           -5.0391e-01, -6.3672e-01],\n",
       "          ...,\n",
       "          [-3.8086e-02,  9.5703e-02, -9.0625e-01,  ...,  8.7891e-01,\n",
       "           -9.0625e-01, -2.0781e+00],\n",
       "          [ 1.3438e+00,  3.1250e-01, -3.7891e-01,  ...,  4.2188e-01,\n",
       "           -1.2344e+00, -2.5000e+00],\n",
       "          [-1.3281e-01, -2.6562e-01, -6.7871e-02,  ..., -1.4746e-01,\n",
       "           -1.1719e+00,  4.6680e-01]],\n",
       "\n",
       "         [[-1.4496e-04,  1.6479e-03, -1.8555e-02,  ..., -2.9883e-01,\n",
       "            5.1758e-02,  1.4893e-02],\n",
       "          [-4.2969e-02,  1.1562e+00,  3.0469e-01,  ...,  2.1973e-01,\n",
       "           -1.2695e-01,  3.1445e-01],\n",
       "          [-3.2031e-01,  3.7891e-01, -1.0547e-01,  ...,  3.4570e-01,\n",
       "            3.7305e-01,  2.5195e-01],\n",
       "          ...,\n",
       "          [ 3.5938e-01, -3.8672e-01, -5.5078e-01,  ...,  1.1094e+00,\n",
       "            6.9922e-01,  2.9492e-01],\n",
       "          [-2.0703e-01, -3.5938e-01,  3.1250e-01,  ...,  5.5859e-01,\n",
       "           -1.3965e-01, -4.6484e-01],\n",
       "          [ 3.8281e-01,  3.1445e-01,  9.0234e-01,  ..., -9.4531e-01,\n",
       "            5.4688e-01,  2.1094e+00]],\n",
       "\n",
       "         [[ 1.1063e-03, -1.0925e-02, -4.8218e-03,  ..., -6.7383e-02,\n",
       "            1.8652e-01, -2.5757e-02],\n",
       "          [ 3.8330e-02,  9.6094e-01, -8.5156e-01,  ..., -6.2500e-01,\n",
       "           -1.1328e-01, -1.5078e+00],\n",
       "          [-2.9102e-01,  7.1484e-01, -4.5312e-01,  ...,  3.1055e-01,\n",
       "           -3.0859e-01, -1.3750e+00],\n",
       "          ...,\n",
       "          [-3.2617e-01, -4.2969e-01,  8.9844e-02,  ...,  1.5391e+00,\n",
       "           -8.5938e-01,  7.2266e-01],\n",
       "          [-4.4531e-01, -5.5469e-01, -4.4336e-01,  ...,  2.4512e-01,\n",
       "           -1.0078e+00, -6.3672e-01],\n",
       "          [-1.5820e-01,  1.2891e-01, -2.3828e-01,  ..., -1.7891e+00,\n",
       "           -5.2344e-01, -4.0234e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 9.7046e-03, -4.6387e-03,  1.1169e-02,  ..., -1.4258e-01,\n",
       "           -6.9336e-02,  4.5166e-02],\n",
       "          [ 1.7188e-01, -6.9141e-01,  8.5938e-02,  ..., -1.4453e-01,\n",
       "           -2.1973e-01,  2.8198e-02],\n",
       "          [-3.7891e-01,  9.1797e-01,  5.6641e-01,  ...,  6.8750e-01,\n",
       "           -1.0391e+00, -8.7500e-01],\n",
       "          ...,\n",
       "          [-4.1211e-01, -1.0156e+00,  2.5391e-01,  ...,  1.3672e-01,\n",
       "           -9.7656e-01, -4.2969e-01],\n",
       "          [ 1.8848e-01, -3.2031e-01, -4.4531e-01,  ..., -2.9102e-01,\n",
       "           -6.9922e-01, -7.3828e-01],\n",
       "          [-4.7461e-01,  9.2188e-01, -1.8906e+00,  ..., -1.7422e+00,\n",
       "           -6.4844e-01, -1.0254e-02]],\n",
       "\n",
       "         [[-3.8300e-03, -2.9373e-04, -8.9722e-03,  ...,  7.1289e-02,\n",
       "           -1.6895e-01, -2.5586e-01],\n",
       "          [-2.0410e-01, -1.4258e-01,  2.8906e-01,  ..., -1.7285e-01,\n",
       "            5.8203e-01, -1.9141e-01],\n",
       "          [ 1.3281e-01, -3.9453e-01,  8.2031e-01,  ..., -5.3125e-01,\n",
       "           -2.2363e-01,  3.1055e-01],\n",
       "          ...,\n",
       "          [-7.1094e-01,  1.1426e-01,  1.6699e-01,  ...,  1.4609e+00,\n",
       "            1.7266e+00,  3.8086e-01],\n",
       "          [ 2.1191e-01,  2.8516e-01,  2.5000e-01,  ..., -1.3965e-01,\n",
       "            1.2188e+00,  9.9609e-01],\n",
       "          [ 9.1016e-01,  0.0000e+00,  5.5859e-01,  ..., -9.5312e-01,\n",
       "            7.4219e-01,  7.3047e-01]],\n",
       "\n",
       "         [[-3.0670e-03,  1.2085e-02, -1.9653e-02,  ...,  2.2949e-01,\n",
       "            2.0605e-01,  8.7891e-02],\n",
       "          [ 3.8672e-01,  1.0156e-01,  6.2500e-01,  ...,  2.3047e-01,\n",
       "           -1.3855e-02,  1.9766e+00],\n",
       "          [ 9.3750e-02, -5.2734e-01,  9.7656e-01,  ..., -1.1963e-01,\n",
       "           -7.5391e-01,  2.9375e+00],\n",
       "          ...,\n",
       "          [-6.8750e-01, -6.5625e-01, -3.5547e-01,  ...,  7.8125e-01,\n",
       "           -2.2559e-01,  1.0938e-01],\n",
       "          [-3.1055e-01, -4.4141e-01, -8.3984e-01,  ...,  9.1406e-01,\n",
       "            2.3145e-01, -9.1797e-01],\n",
       "          [ 1.8164e-01,  5.2734e-01, -3.0078e-01,  ...,  2.1719e+00,\n",
       "            2.4062e+00,  1.9141e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 2.9449e-03,  5.4626e-03,  4.4250e-03,  ..., -2.3346e-03,\n",
       "           -5.9204e-03,  5.1270e-03],\n",
       "          [ 8.3203e-01,  5.1953e-01,  3.3008e-01,  ..., -3.2422e-01,\n",
       "           -4.5703e-01, -1.6309e-01],\n",
       "          [ 6.4062e-01,  3.4570e-01,  7.5391e-01,  ..., -6.5430e-02,\n",
       "           -6.5918e-02, -2.5513e-02],\n",
       "          ...,\n",
       "          [-6.3672e-01,  5.3711e-02,  1.4648e-01,  ...,  8.7891e-02,\n",
       "            1.4141e+00,  1.1406e+00],\n",
       "          [-1.4062e-01, -2.6367e-01, -2.0386e-02,  ...,  2.5195e-01,\n",
       "            6.2500e-02,  6.0938e-01],\n",
       "          [-4.0527e-02, -3.7695e-01, -6.9922e-01,  ...,  1.4844e-01,\n",
       "           -4.1016e-01, -1.8555e-01]],\n",
       "\n",
       "         [[ 1.5030e-03, -2.3315e-02,  2.3041e-03,  ..., -9.7046e-03,\n",
       "           -2.5879e-02, -1.0071e-02],\n",
       "          [ 4.5703e-01,  7.4609e-01, -4.6289e-01,  ..., -6.0547e-02,\n",
       "            9.2773e-03,  3.7500e-01],\n",
       "          [-2.5146e-02, -2.3340e-01,  6.7969e-01,  ..., -1.0938e-01,\n",
       "            5.5847e-03, -8.5449e-03],\n",
       "          ...,\n",
       "          [ 6.3672e-01, -3.3008e-01,  5.1953e-01,  ...,  7.8613e-02,\n",
       "            6.9922e-01, -4.1992e-01],\n",
       "          [-2.6953e-01,  6.3281e-01,  5.8984e-01,  ..., -1.6699e-01,\n",
       "            6.5234e-01, -3.9062e-01],\n",
       "          [ 2.0215e-01, -1.0254e-01, -1.1328e-01,  ...,  2.4658e-02,\n",
       "           -2.4902e-01, -7.0801e-02]],\n",
       "\n",
       "         [[ 8.3008e-03,  5.0354e-03, -1.5640e-03,  ...,  4.9561e-02,\n",
       "           -9.0408e-04, -1.3428e-03],\n",
       "          [-2.6953e-01,  2.6953e-01, -5.5859e-01,  ..., -2.8711e-01,\n",
       "           -7.9297e-01,  2.1973e-01],\n",
       "          [-2.0117e-01,  3.8672e-01, -2.2559e-01,  ..., -5.5859e-01,\n",
       "           -5.7812e-01,  1.8066e-01],\n",
       "          ...,\n",
       "          [ 1.1963e-01,  3.6133e-02, -1.8359e-01,  ..., -3.0396e-02,\n",
       "           -4.5117e-01, -2.5391e-01],\n",
       "          [ 3.9648e-01, -2.2266e-01,  9.1309e-02,  ...,  1.9165e-02,\n",
       "           -3.2422e-01, -3.0859e-01],\n",
       "          [ 4.1797e-01, -1.7334e-02, -3.0664e-01,  ..., -9.6191e-02,\n",
       "           -1.0625e+00, -4.2578e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-9.5215e-03, -2.9907e-03,  7.4219e-02,  ...,  4.3030e-03,\n",
       "           -5.1270e-02, -1.3611e-02],\n",
       "          [ 2.0410e-01,  6.2109e-01, -4.1992e-02,  ..., -9.5703e-02,\n",
       "            2.4219e-01,  4.1016e-01],\n",
       "          [ 2.9492e-01,  2.8320e-01, -5.1953e-01,  ..., -1.1572e-01,\n",
       "            9.8047e-01,  1.2988e-01],\n",
       "          ...,\n",
       "          [ 9.0234e-01, -1.6211e-01, -4.7852e-01,  ..., -1.6602e-01,\n",
       "            6.6406e-01,  7.6562e-01],\n",
       "          [ 4.9414e-01,  1.6797e-01,  1.9629e-01,  ..., -8.1250e-01,\n",
       "            2.2656e-01,  6.3281e-01],\n",
       "          [ 8.8281e-01, -1.7480e-01, -1.8457e-01,  ..., -7.4219e-01,\n",
       "            1.0889e-01,  2.3828e-01]],\n",
       "\n",
       "         [[-5.7068e-03,  4.8218e-03, -1.2512e-02,  ...,  2.6489e-02,\n",
       "           -8.0566e-03,  3.6621e-03],\n",
       "          [ 4.6387e-02, -6.7578e-01, -2.0874e-02,  ...,  1.0781e+00,\n",
       "            8.7500e-01,  1.8262e-01],\n",
       "          [ 2.3730e-01, -6.7969e-01, -3.8086e-01,  ...,  1.0391e+00,\n",
       "            6.9141e-01,  4.3945e-01],\n",
       "          ...,\n",
       "          [ 2.9297e-01, -1.1094e+00,  4.9609e-01,  ...,  8.5449e-02,\n",
       "           -1.0234e+00,  7.0703e-01],\n",
       "          [-1.4832e-02, -8.7891e-01,  4.5898e-01,  ...,  7.5000e-01,\n",
       "           -1.2031e+00,  2.2461e-01],\n",
       "          [ 3.8867e-01, -4.4727e-01, -3.7695e-01,  ...,  3.5547e-01,\n",
       "            8.8867e-02,  4.0234e-01]],\n",
       "\n",
       "         [[ 8.7280e-03, -5.3711e-03,  7.1411e-03,  ..., -5.1575e-03,\n",
       "            1.5259e-02, -3.0518e-03],\n",
       "          [ 7.7344e-01,  2.2949e-01,  9.0234e-01,  ...,  6.6406e-01,\n",
       "            7.3047e-01, -6.2109e-01],\n",
       "          [ 4.9609e-01, -9.0820e-02,  6.0547e-01,  ...,  6.5234e-01,\n",
       "            4.7852e-01, -3.3984e-01],\n",
       "          ...,\n",
       "          [-4.6094e-01,  9.0332e-02,  1.4062e-01,  ..., -4.7852e-01,\n",
       "           -7.8125e-01,  7.1094e-01],\n",
       "          [-9.1016e-01,  4.9414e-01,  7.5391e-01,  ..., -4.0039e-01,\n",
       "           -5.6250e-01,  9.8828e-01],\n",
       "          [-6.7578e-01,  2.3340e-01, -1.7188e-01,  ..., -1.4355e-01,\n",
       "           -6.7578e-01,  8.6719e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-6.8848e-02,  2.3804e-02, -7.7515e-03,  ..., -1.8799e-02,\n",
       "            2.3145e-01, -1.3828e+00],\n",
       "          [ 9.0820e-02, -2.1582e-01, -2.5391e-02,  ...,  1.2266e+00,\n",
       "           -1.1523e-01,  7.0000e+00],\n",
       "          [ 8.0566e-02,  7.7637e-02,  1.6602e-01,  ...,  1.4688e+00,\n",
       "           -6.2109e-01,  6.5000e+00],\n",
       "          ...,\n",
       "          [-2.8516e-01,  4.4189e-02, -8.0078e-02,  ...,  1.7031e+00,\n",
       "           -1.4219e+00,  4.3438e+00],\n",
       "          [ 6.0156e-01,  1.2695e-01, -7.0312e-02,  ...,  1.7422e+00,\n",
       "           -2.0781e+00,  3.4375e+00],\n",
       "          [ 5.4688e-01, -1.9629e-01,  2.1094e-01,  ...,  1.0312e+00,\n",
       "            1.5793e-03,  3.4844e+00]],\n",
       "\n",
       "         [[ 4.2725e-02, -1.6174e-03,  3.4424e-02,  ..., -1.3672e-01,\n",
       "            1.8672e+00, -8.4473e-02],\n",
       "          [ 9.4727e-02, -1.8750e-01, -2.9297e-02,  ..., -6.0547e-01,\n",
       "           -4.2500e+00,  1.7578e+00],\n",
       "          [-1.1719e-02, -1.4062e-01,  3.1055e-01,  ..., -3.2422e-01,\n",
       "           -5.4062e+00,  6.6016e-01],\n",
       "          ...,\n",
       "          [-2.3193e-02, -5.8594e-03,  2.8906e-01,  ...,  6.7969e-01,\n",
       "           -5.5312e+00, -1.0625e+00],\n",
       "          [-2.2070e-01, -1.5625e-01,  1.6602e-01,  ...,  1.3574e-01,\n",
       "           -5.4688e+00, -7.3828e-01],\n",
       "          [ 5.1270e-03,  1.4355e-01, -3.0469e-01,  ...,  3.1494e-02,\n",
       "           -3.6562e+00, -2.3750e+00]],\n",
       "\n",
       "         [[-9.3994e-03, -1.9043e-02,  1.0620e-02,  ...,  6.8848e-02,\n",
       "           -5.3906e-01,  7.5000e-01],\n",
       "          [ 1.2695e-02, -6.7188e-01,  7.5195e-02,  ...,  1.5391e+00,\n",
       "            3.8477e-01, -1.4453e+00],\n",
       "          [ 9.7656e-01, -4.2578e-01,  9.5703e-02,  ...,  1.2109e-01,\n",
       "            4.9219e-01, -1.3984e+00],\n",
       "          ...,\n",
       "          [ 6.2500e-02, -8.7109e-01, -5.5469e-01,  ...,  9.0234e-01,\n",
       "            3.6914e-01, -2.3750e+00],\n",
       "          [ 3.2812e-01, -2.4805e-01, -4.6094e-01,  ...,  1.5078e+00,\n",
       "           -1.1797e+00, -2.1094e+00],\n",
       "          [ 8.0078e-01,  1.8945e-01,  7.5781e-01,  ..., -3.4180e-01,\n",
       "           -1.7285e-01, -2.1719e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.5391e-02, -4.0283e-02,  2.9297e-03,  ..., -6.0938e-01,\n",
       "           -6.9824e-02, -3.2812e-01],\n",
       "          [ 2.8125e-01,  6.8359e-02, -4.6484e-01,  ...,  1.3125e+00,\n",
       "           -4.4688e+00, -2.7344e+00],\n",
       "          [ 4.7266e-01,  1.6602e-01, -3.5889e-02,  ...,  2.1719e+00,\n",
       "           -3.6094e+00, -1.6797e+00],\n",
       "          ...,\n",
       "          [-4.0771e-02,  9.7656e-02,  4.1211e-01,  ...,  3.2500e+00,\n",
       "           -8.1641e-01, -1.3125e+00],\n",
       "          [-2.1387e-01,  4.4922e-02,  4.2969e-01,  ...,  1.2344e+00,\n",
       "           -1.2031e+00, -4.8438e-01],\n",
       "          [ 1.4551e-01, -9.2773e-02,  1.5430e-01,  ...,  1.1797e+00,\n",
       "           -6.4062e-01,  7.9297e-01]],\n",
       "\n",
       "         [[ 3.4424e-02,  7.1716e-03, -9.7046e-03,  ...,  3.7109e-01,\n",
       "           -1.1094e+00,  9.0234e-01],\n",
       "          [-1.6211e-01, -2.1191e-01,  1.1182e-01,  ..., -5.2188e+00,\n",
       "            1.8281e+00, -2.5156e+00],\n",
       "          [-3.7109e-01, -4.2969e-01, -9.1406e-01,  ..., -4.3750e+00,\n",
       "            1.5547e+00, -2.9531e+00],\n",
       "          ...,\n",
       "          [ 3.4180e-02, -7.7148e-02,  2.3828e-01,  ..., -9.0234e-01,\n",
       "            5.5938e+00,  4.1211e-01],\n",
       "          [-1.9043e-02,  2.0996e-02,  4.1406e-01,  ...,  1.0938e+00,\n",
       "            3.2656e+00, -2.0781e+00],\n",
       "          [ 1.9434e-01, -2.2852e-01, -3.7109e-01,  ...,  3.5352e-01,\n",
       "            1.2656e+00,  1.0781e+00]],\n",
       "\n",
       "         [[-6.6223e-03, -1.0925e-02, -2.4261e-03,  ..., -1.8066e-01,\n",
       "           -4.7461e-01,  1.2266e+00],\n",
       "          [ 3.9062e-01,  7.4219e-02, -5.2734e-02,  ...,  2.5000e+00,\n",
       "           -1.5625e+00, -3.4531e+00],\n",
       "          [-6.8359e-02,  3.0859e-01,  1.6602e-02,  ...,  2.2969e+00,\n",
       "           -4.8828e-01, -3.4375e+00],\n",
       "          ...,\n",
       "          [ 3.9795e-02,  3.5352e-01,  1.5137e-01,  ..., -6.7188e-01,\n",
       "            2.0469e+00, -6.3125e+00],\n",
       "          [-2.3145e-01, -5.5469e-01,  9.6094e-01,  ..., -9.8438e-01,\n",
       "            9.9219e-01, -5.4062e+00],\n",
       "          [-1.6016e-01,  3.3984e-01, -3.2031e-01,  ...,  8.5547e-01,\n",
       "            1.3203e+00, -3.3594e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 5.3406e-03, -1.1658e-02, -5.8594e-03,  ...,  5.2795e-03,\n",
       "            3.8300e-03, -1.4267e-03],\n",
       "          [-8.3496e-02,  4.2383e-01,  4.8828e-02,  ...,  8.5547e-01,\n",
       "            4.7607e-02,  5.1172e-01],\n",
       "          [-2.8516e-01,  7.1094e-01,  4.1406e-01,  ...,  7.6953e-01,\n",
       "            1.5991e-02,  5.2734e-01],\n",
       "          ...,\n",
       "          [-1.7676e-01,  3.6621e-02,  1.1621e-01,  ..., -6.5234e-01,\n",
       "            3.4961e-01,  8.7891e-02],\n",
       "          [-4.2969e-01, -1.5137e-01,  1.2598e-01,  ..., -7.8125e-01,\n",
       "            6.8359e-01,  1.2354e-01],\n",
       "          [ 2.1582e-01,  1.0234e+00,  1.1484e+00,  ...,  7.2266e-01,\n",
       "           -2.3145e-01,  4.6289e-01]],\n",
       "\n",
       "         [[-1.4038e-02, -1.2970e-03, -8.1787e-03,  ...,  2.4658e-02,\n",
       "            2.5482e-03, -6.0425e-03],\n",
       "          [-8.4766e-01,  4.1602e-01, -1.0352e-01,  ..., -3.2227e-01,\n",
       "            6.8750e-01, -7.5391e-01],\n",
       "          [-3.5742e-01, -4.8340e-02, -1.0156e+00,  ..., -3.6719e-01,\n",
       "            2.4609e-01, -2.6562e-01],\n",
       "          ...,\n",
       "          [-2.5781e-01,  5.7422e-01,  6.2988e-02,  ...,  9.6191e-02,\n",
       "            6.4453e-01, -1.2793e-01],\n",
       "          [-2.6367e-01,  3.2031e-01, -3.8281e-01,  ..., -4.0820e-01,\n",
       "            8.4766e-01, -1.9531e-01],\n",
       "          [ 5.2185e-03,  4.9805e-01,  2.1240e-02,  ..., -2.7008e-03,\n",
       "            3.6523e-01,  2.0020e-01]],\n",
       "\n",
       "         [[-3.3722e-03,  4.0283e-03,  1.5411e-03,  ...,  1.4801e-03,\n",
       "            2.4109e-03,  7.8735e-03],\n",
       "          [ 5.5078e-01,  9.6484e-01, -2.0312e-01,  ...,  1.4877e-03,\n",
       "            5.4688e-01,  1.9238e-01],\n",
       "          [ 8.0078e-01,  7.4219e-01, -1.4941e-01,  ...,  2.1680e-01,\n",
       "           -3.9844e-01,  1.6309e-01],\n",
       "          ...,\n",
       "          [-1.9336e-01, -5.9375e-01,  1.0625e+00,  ..., -6.6406e-01,\n",
       "           -2.4316e-01,  4.7266e-01],\n",
       "          [-9.9609e-01,  5.7031e-01,  4.8242e-01,  ..., -6.8750e-01,\n",
       "           -4.1211e-01,  8.2520e-02],\n",
       "          [-4.2188e-01,  4.8438e-01,  1.6309e-01,  ..., -7.1875e-01,\n",
       "           -2.6953e-01, -2.9492e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 3.6926e-03, -1.0834e-03, -7.5989e-03,  ..., -3.6774e-03,\n",
       "           -1.8845e-03, -4.2419e-03],\n",
       "          [ 1.6406e-01, -7.4219e-02,  3.2617e-01,  ..., -6.4453e-01,\n",
       "           -3.8672e-01,  1.1641e+00],\n",
       "          [-9.0820e-02, -3.4766e-01, -2.2949e-01,  ..., -5.4297e-01,\n",
       "           -2.6001e-02,  9.2578e-01],\n",
       "          ...,\n",
       "          [ 3.8086e-02,  3.8086e-02, -1.5820e-01,  ...,  7.8516e-01,\n",
       "           -1.1865e-01, -2.7539e-01],\n",
       "          [ 8.8281e-01,  3.1836e-01,  1.0791e-01,  ...,  3.9648e-01,\n",
       "            5.7129e-02,  7.7148e-02],\n",
       "          [-1.0596e-01,  2.5781e-01,  7.1484e-01,  ...,  6.4941e-02,\n",
       "           -2.3828e-01,  5.0049e-02]],\n",
       "\n",
       "         [[ 1.3428e-03, -5.1575e-03,  8.0566e-03,  ...,  5.3711e-03,\n",
       "           -5.3101e-03, -2.0874e-02],\n",
       "          [ 7.6172e-01,  2.9883e-01, -2.6172e-01,  ..., -6.1279e-02,\n",
       "            8.1250e-01,  5.6152e-02],\n",
       "          [ 5.0293e-02, -4.7266e-01, -2.6758e-01,  ...,  3.5645e-02,\n",
       "            4.6680e-01,  1.6602e-01],\n",
       "          ...,\n",
       "          [-1.0547e+00,  1.1172e+00,  2.5977e-01,  ..., -6.4062e-01,\n",
       "            5.8594e-01,  1.3125e+00],\n",
       "          [-1.2578e+00,  6.6016e-01, -5.6641e-01,  ..., -1.1328e-01,\n",
       "            4.4141e-01,  1.0859e+00],\n",
       "          [-8.0566e-02,  3.8086e-01,  1.5137e-01,  ..., -7.5781e-01,\n",
       "            6.2500e-01,  4.0430e-01]],\n",
       "\n",
       "         [[ 2.6093e-03,  2.6398e-03,  3.7537e-03,  ..., -1.0010e-02,\n",
       "           -2.3041e-03,  1.4725e-03],\n",
       "          [-3.0273e-01, -1.0938e+00,  1.2500e-01,  ...,  1.2451e-01,\n",
       "           -7.3438e-01,  5.2246e-02],\n",
       "          [-7.0703e-01, -1.2578e+00,  2.0801e-01,  ...,  5.3516e-01,\n",
       "           -4.2969e-01, -2.0605e-01],\n",
       "          ...,\n",
       "          [ 1.7500e+00,  6.6797e-01, -1.2188e+00,  ..., -8.3203e-01,\n",
       "           -2.4609e-01, -3.7109e-01],\n",
       "          [ 1.2656e+00,  5.6250e-01, -5.4297e-01,  ..., -2.0703e-01,\n",
       "           -1.4844e-01, -2.2339e-02],\n",
       "          [ 1.5234e+00, -2.1729e-02, -9.7656e-01,  ..., -1.9922e-01,\n",
       "           -1.9336e-01,  1.5039e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-5.0659e-03, -1.3184e-02, -6.3782e-03,  ...,  2.0312e-01,\n",
       "           -1.0400e-01, -2.0996e-01],\n",
       "          [-4.7852e-02, -7.5684e-02,  7.6172e-02,  ..., -1.6641e+00,\n",
       "            1.6172e+00,  2.8438e+00],\n",
       "          [ 5.4688e-01,  1.6602e-01,  3.4180e-01,  ..., -3.4180e-01,\n",
       "            1.7344e+00,  2.4844e+00],\n",
       "          ...,\n",
       "          [-2.4512e-01,  2.5195e-01, -2.5000e-01,  ...,  2.2344e+00,\n",
       "            1.3047e+00,  1.4062e+00],\n",
       "          [-2.2070e-01,  2.2266e-01, -8.3496e-02,  ...,  1.8828e+00,\n",
       "            1.3906e+00,  9.7656e-01],\n",
       "          [ 1.2109e-01,  2.7930e-01,  5.3711e-02,  ...,  2.9531e+00,\n",
       "            8.2422e-01,  2.5000e+00]],\n",
       "\n",
       "         [[ 2.9297e-02, -7.7515e-03,  4.0283e-03,  ...,  3.1445e-01,\n",
       "            1.0859e+00, -3.3203e-01],\n",
       "          [ 7.2266e-02, -1.8750e-01,  6.5234e-01,  ..., -3.0156e+00,\n",
       "           -3.9219e+00,  1.1016e+00],\n",
       "          [ 2.6953e-01, -2.1875e-01,  2.8125e-01,  ..., -1.8516e+00,\n",
       "           -4.2812e+00,  7.3828e-01],\n",
       "          ...,\n",
       "          [ 1.9531e-01,  1.1377e-01, -3.9648e-01,  ..., -3.1406e+00,\n",
       "           -3.5781e+00,  1.1172e+00],\n",
       "          [ 7.7148e-02,  1.9727e-01, -3.4180e-01,  ..., -4.5625e+00,\n",
       "           -4.1562e+00,  3.0078e-01],\n",
       "          [-2.0996e-02,  1.2939e-02,  7.8125e-03,  ..., -2.9062e+00,\n",
       "           -3.4688e+00, -5.3125e-01]],\n",
       "\n",
       "         [[-1.5625e-02, -8.4839e-03, -7.0801e-03,  ...,  4.5117e-01,\n",
       "            3.5156e-01, -3.8672e-01],\n",
       "          [-3.3203e-02, -4.7266e-01, -1.4746e-01,  ..., -1.6250e+00,\n",
       "            1.0132e-02, -3.3398e-01],\n",
       "          [-7.0312e-02,  7.4219e-01, -4.3359e-01,  ..., -1.1094e+00,\n",
       "           -5.7422e-01,  4.9805e-02],\n",
       "          ...,\n",
       "          [-4.1406e-01,  8.4375e-01,  8.3594e-01,  ..., -2.2656e+00,\n",
       "           -4.5703e-01,  1.0986e-01],\n",
       "          [-5.6250e-01,  9.3750e-02,  9.3750e-01,  ..., -1.8828e+00,\n",
       "           -1.9219e+00,  6.6797e-01],\n",
       "          [-1.6504e-01,  2.0703e-01,  2.7148e-01,  ..., -7.4219e-01,\n",
       "           -9.5703e-01,  3.3691e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.3367e-02,  1.0132e-02,  3.3875e-03,  ...,  1.7578e+00,\n",
       "           -3.0469e-01, -3.6328e-01],\n",
       "          [ 4.1406e-01, -4.1602e-01,  3.2617e-01,  ..., -4.9688e+00,\n",
       "           -6.6797e-01,  6.6797e-01],\n",
       "          [ 2.8320e-01, -2.5586e-01, -2.0215e-01,  ..., -5.7812e+00,\n",
       "            1.1035e-01,  8.2812e-01],\n",
       "          ...,\n",
       "          [-3.1445e-01, -2.3145e-01, -1.4258e-01,  ..., -6.4062e+00,\n",
       "           -7.4219e-01,  1.1016e+00],\n",
       "          [ 4.2969e-02, -3.9062e-01,  3.4180e-01,  ..., -6.5938e+00,\n",
       "            1.8203e+00,  1.2578e+00],\n",
       "          [ 8.3008e-02,  4.0283e-02, -2.4023e-01,  ..., -6.0938e+00,\n",
       "            7.1484e-01, -2.9297e-02]],\n",
       "\n",
       "         [[ 2.0020e-02, -1.8433e-02,  2.4414e-03,  ..., -9.8438e-01,\n",
       "           -3.2031e-01, -1.0986e-01],\n",
       "          [-5.7031e-01, -1.5820e-01,  5.9375e-01,  ...,  5.9375e+00,\n",
       "           -7.4609e-01,  2.3906e+00],\n",
       "          [-7.8516e-01,  3.1641e-01, -3.7891e-01,  ...,  6.9062e+00,\n",
       "           -3.8086e-01,  1.7344e+00],\n",
       "          ...,\n",
       "          [-1.6406e-01, -1.6211e-01, -2.9541e-02,  ...,  7.8438e+00,\n",
       "           -1.5000e+00,  2.3594e+00],\n",
       "          [-2.4414e-01,  6.5918e-02, -1.0791e-01,  ...,  6.8125e+00,\n",
       "           -1.0781e+00,  1.7422e+00],\n",
       "          [-2.4902e-01, -5.1172e-01,  6.5625e-01,  ...,  7.6250e+00,\n",
       "           -1.6016e+00, -1.2266e+00]],\n",
       "\n",
       "         [[ 2.6611e-02,  8.4839e-03,  2.0996e-02,  ...,  2.2705e-02,\n",
       "            1.1328e+00, -2.5146e-02],\n",
       "          [-1.8262e-01,  4.0430e-01,  6.7969e-01,  ...,  1.3281e+00,\n",
       "           -3.0469e+00,  1.6484e+00],\n",
       "          [ 2.9297e-02, -2.9297e-01,  3.9453e-01,  ...,  9.9609e-01,\n",
       "           -2.9844e+00,  2.7969e+00],\n",
       "          ...,\n",
       "          [ 4.1602e-01, -7.9102e-02, -1.1719e-02,  ...,  9.0234e-01,\n",
       "           -1.5625e-01,  4.0625e+00],\n",
       "          [-3.4180e-01,  2.8809e-02,  4.4141e-01,  ...,  1.2354e-01,\n",
       "           -1.0234e+00,  1.5391e+00],\n",
       "          [-3.8281e-01,  5.8203e-01,  2.5781e-01,  ..., -2.1484e-01,\n",
       "           -1.7090e-01,  2.4219e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 1.2268e-02, -1.3489e-02, -6.8970e-03,  ..., -1.3489e-02,\n",
       "           -1.7944e-02,  1.1292e-02],\n",
       "          [-2.2949e-01,  8.9355e-02, -2.6367e-01,  ..., -4.1602e-01,\n",
       "           -6.6016e-01, -2.9883e-01],\n",
       "          [-4.1406e-01,  3.6621e-02, -2.4707e-01,  ..., -4.4336e-01,\n",
       "           -1.7871e-01,  9.5703e-02],\n",
       "          ...,\n",
       "          [ 1.3672e-01,  1.6016e-01,  2.1094e-01,  ...,  1.9238e-01,\n",
       "           -3.2031e-01,  4.4434e-02],\n",
       "          [ 1.2695e-01, -4.5654e-02,  7.1484e-01,  ..., -5.6250e-01,\n",
       "            1.2146e-02,  6.9141e-01],\n",
       "          [-6.7383e-02,  1.0010e-01,  5.2344e-01,  ..., -5.4688e-01,\n",
       "            6.9141e-01,  2.1191e-01]],\n",
       "\n",
       "         [[ 7.0496e-03,  2.0386e-02, -1.6022e-03,  ...,  9.8877e-03,\n",
       "           -1.0132e-02,  1.0559e-02],\n",
       "          [-8.3008e-02, -8.6328e-01, -4.0625e-01,  ...,  1.7773e-01,\n",
       "            2.5195e-01,  4.6680e-01],\n",
       "          [-5.2734e-01, -2.8906e-01, -5.6641e-01,  ..., -3.3691e-02,\n",
       "           -4.6875e-01,  1.8555e-01],\n",
       "          ...,\n",
       "          [ 5.2490e-02,  3.9648e-01,  4.5117e-01,  ...,  3.2471e-02,\n",
       "            3.4180e-01,  1.4648e-01],\n",
       "          [ 8.3203e-01,  6.4844e-01,  1.2109e+00,  ..., -5.3125e-01,\n",
       "            6.8750e-01, -8.2031e-02],\n",
       "          [-1.2256e-01, -2.5195e-01,  5.1562e-01,  ..., -6.8750e-01,\n",
       "            7.1289e-02,  5.5859e-01]],\n",
       "\n",
       "         [[ 1.7334e-02,  1.6479e-03, -6.7749e-03,  ..., -8.2520e-02,\n",
       "            1.1921e-04, -6.9580e-03],\n",
       "          [-1.1016e+00,  1.9531e-01, -1.6113e-01,  ..., -1.1016e+00,\n",
       "            3.7109e-01,  2.0752e-02],\n",
       "          [-3.5742e-01, -4.7852e-02, -6.2891e-01,  ..., -4.2188e-01,\n",
       "            5.9375e-01,  8.3496e-02],\n",
       "          ...,\n",
       "          [-8.5938e-01, -1.5820e-01,  3.9648e-01,  ..., -7.6172e-01,\n",
       "            5.9814e-02, -2.7930e-01],\n",
       "          [-7.6172e-01, -1.7700e-02,  1.0469e+00,  ..., -8.3984e-01,\n",
       "            3.1055e-01, -6.5625e-01],\n",
       "          [ 1.6797e-01, -5.0000e-01, -4.4922e-01,  ..., -1.9434e-01,\n",
       "           -6.2109e-01, -1.0859e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.5869e-02, -7.8735e-03, -3.1738e-03,  ...,  5.2795e-03,\n",
       "            2.1210e-03,  6.3171e-03],\n",
       "          [ 6.7969e-01, -9.1797e-02,  2.0020e-01,  ..., -3.3691e-02,\n",
       "            1.2656e+00, -1.2891e-01],\n",
       "          [ 2.4414e-01,  5.6641e-01, -1.9653e-02,  ..., -4.4531e-01,\n",
       "            6.3672e-01, -3.9258e-01],\n",
       "          ...,\n",
       "          [ 6.9531e-01,  3.8867e-01, -5.0293e-02,  ..., -2.0508e-01,\n",
       "            2.4219e-01,  6.2109e-01],\n",
       "          [ 5.7422e-01,  1.1797e+00,  2.2656e-01,  ...,  7.3242e-02,\n",
       "            6.3672e-01,  1.9141e-01],\n",
       "          [ 5.2344e-01, -1.8848e-01, -3.6719e-01,  ...,  6.1328e-01,\n",
       "           -1.0107e-01,  1.8262e-01]],\n",
       "\n",
       "         [[-1.2878e-02, -1.0010e-02, -5.8899e-03,  ..., -2.3804e-03,\n",
       "            3.8147e-03, -1.6235e-02],\n",
       "          [-5.8594e-01,  6.7578e-01, -3.6719e-01,  ...,  4.1406e-01,\n",
       "           -2.0117e-01,  1.5039e-01],\n",
       "          [ 7.0312e-02,  9.2578e-01, -3.4912e-02,  ...,  3.3203e-02,\n",
       "            4.3335e-03,  4.6094e-01],\n",
       "          ...,\n",
       "          [-3.7598e-02,  4.3164e-01,  1.6235e-02,  ...,  1.9629e-01,\n",
       "            4.1602e-01,  4.2188e-01],\n",
       "          [ 6.2500e-01,  8.0078e-01, -5.2734e-01,  ...,  2.9102e-01,\n",
       "            2.4121e-01,  3.3789e-01],\n",
       "          [-2.8320e-01,  2.9297e-01,  1.0391e+00,  ..., -9.4238e-02,\n",
       "           -4.4727e-01, -7.2266e-01]],\n",
       "\n",
       "         [[-1.1414e-02, -7.0190e-03, -8.9844e-02,  ...,  1.9989e-03,\n",
       "            5.1117e-04, -4.1504e-03],\n",
       "          [ 2.0508e-01, -6.2500e-01,  4.8633e-01,  ..., -2.8906e-01,\n",
       "            2.0703e-01,  6.6797e-01],\n",
       "          [ 1.9336e-01, -5.5078e-01,  6.2109e-01,  ..., -4.4727e-01,\n",
       "           -2.4414e-01,  4.3750e-01],\n",
       "          ...,\n",
       "          [-6.2500e-01, -6.4062e-01,  1.9688e+00,  ..., -1.9844e+00,\n",
       "            3.0975e-03, -4.7656e-01],\n",
       "          [-7.0703e-01, -1.2158e-01,  1.8281e+00,  ..., -1.4688e+00,\n",
       "            5.9766e-01, -3.2422e-01],\n",
       "          [-1.8750e-01, -9.2578e-01,  1.2812e+00,  ..., -8.9453e-01,\n",
       "           -7.4219e-02,  3.6328e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 8.1177e-03, -4.6997e-03,  1.5198e-02,  ..., -2.7539e-01,\n",
       "            3.9062e-02, -5.0000e-01],\n",
       "          [-3.6914e-01,  8.7109e-01,  4.5898e-01,  ..., -9.0234e-01,\n",
       "           -5.7812e-01,  5.5469e-01],\n",
       "          [ 9.2188e-01,  4.7461e-01,  4.2188e-01,  ...,  1.7188e-01,\n",
       "            2.0996e-02,  4.9805e-01],\n",
       "          ...,\n",
       "          [-9.2188e-01,  2.9297e-01,  1.1094e+00,  ...,  7.6953e-01,\n",
       "            2.0312e+00, -1.6602e-01],\n",
       "          [-1.1016e+00, -1.8164e-01,  7.6562e-01,  ...,  1.6406e+00,\n",
       "            2.0625e+00, -3.9795e-02],\n",
       "          [-8.5938e-02,  1.1172e+00,  4.8828e-01,  ...,  2.7188e+00,\n",
       "            1.6953e+00,  6.6016e-01]],\n",
       "\n",
       "         [[-1.6724e-02, -9.3994e-03,  1.9287e-02,  ...,  3.1128e-02,\n",
       "            3.0469e-01, -5.0781e-02],\n",
       "          [-1.1250e+00, -4.5703e-01, -2.3438e-01,  ...,  3.8867e-01,\n",
       "            6.4453e-01, -4.0625e-01],\n",
       "          [-4.1602e-01,  3.1250e-01, -3.4180e-01,  ...,  1.1182e-01,\n",
       "            7.8516e-01, -9.9219e-01],\n",
       "          ...,\n",
       "          [ 1.2207e-01, -8.2812e-01,  2.0508e-01,  ..., -1.2344e+00,\n",
       "           -1.0859e+00, -1.0000e+00],\n",
       "          [-3.3203e-02, -1.5000e+00,  8.7891e-01,  ...,  1.1797e+00,\n",
       "           -1.4922e+00, -3.3936e-02],\n",
       "          [-5.1562e-01, -7.5781e-01, -2.0215e-01,  ..., -1.4844e+00,\n",
       "           -1.6504e-01,  3.7500e-01]],\n",
       "\n",
       "         [[-1.1215e-03,  2.2827e-02, -7.2937e-03,  ...,  3.3203e-02,\n",
       "            1.4531e+00, -5.7129e-02],\n",
       "          [ 1.9375e+00,  1.2031e+00, -1.2578e+00,  ..., -6.6016e-01,\n",
       "           -3.5156e+00, -3.1641e-01],\n",
       "          [ 2.4219e-01,  4.7070e-01, -3.5547e-01,  ...,  2.8711e-01,\n",
       "           -4.0312e+00,  8.3594e-01],\n",
       "          ...,\n",
       "          [-1.3770e-01,  6.6406e-01, -2.0215e-01,  ...,  7.4219e-01,\n",
       "           -4.5312e+00,  6.9141e-01],\n",
       "          [ 3.9258e-01, -5.3516e-01, -5.5469e-01,  ..., -3.0859e-01,\n",
       "           -4.2812e+00, -4.4141e-01],\n",
       "          [ 5.5859e-01,  8.5156e-01, -1.4062e+00,  ...,  9.5703e-01,\n",
       "           -4.7500e+00,  1.9336e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 7.1716e-03,  3.9795e-02,  1.0742e-02,  ...,  4.0625e-01,\n",
       "           -1.4160e-01, -3.6328e-01],\n",
       "          [-2.5391e-01, -5.8594e-01,  1.4551e-01,  ...,  5.5469e-01,\n",
       "           -3.4180e-01,  2.2168e-01],\n",
       "          [-6.3281e-01, -6.1719e-01,  5.1172e-01,  ..., -7.3438e-01,\n",
       "            9.2578e-01,  1.6328e+00],\n",
       "          ...,\n",
       "          [ 1.2422e+00, -3.3203e-02,  9.5312e-01,  ..., -4.2969e-01,\n",
       "           -9.1309e-02,  2.5000e+00],\n",
       "          [ 7.5684e-02, -2.2656e-01, -2.4023e-01,  ..., -1.4453e-01,\n",
       "           -3.3281e+00,  1.5625e+00],\n",
       "          [-5.6641e-01,  3.5742e-01, -1.7578e-01,  ...,  2.7148e-01,\n",
       "            3.2812e-01,  1.7500e+00]],\n",
       "\n",
       "         [[-1.0315e-02,  3.1738e-02,  1.8311e-02,  ...,  2.1250e+00,\n",
       "            2.2754e-01,  5.1562e-01],\n",
       "          [ 2.2461e-01, -5.3125e-01, -1.1562e+00,  ..., -3.2031e+00,\n",
       "            4.1992e-01, -4.9688e+00],\n",
       "          [ 2.2461e-02, -4.2188e-01, -4.3750e-01,  ..., -4.6562e+00,\n",
       "           -7.4609e-01, -4.6250e+00],\n",
       "          ...,\n",
       "          [-3.6328e-01,  4.6484e-01,  3.5156e-02,  ..., -4.5938e+00,\n",
       "           -8.7891e-01, -1.9238e-01],\n",
       "          [ 2.1973e-01,  4.6680e-01,  1.0312e+00,  ..., -5.0938e+00,\n",
       "            5.7422e-01, -4.6875e-01],\n",
       "          [ 4.9805e-01,  4.1016e-02,  4.4141e-01,  ..., -3.5000e+00,\n",
       "            1.1016e+00, -1.9922e+00]],\n",
       "\n",
       "         [[ 3.7842e-02,  5.4199e-02, -3.0518e-02,  ..., -1.1641e+00,\n",
       "           -2.0508e-01,  3.1055e-01],\n",
       "          [-1.8262e-01, -5.0781e-01, -1.6113e-01,  ...,  1.9297e+00,\n",
       "           -2.3828e-01,  3.3984e-01],\n",
       "          [-4.2188e-01, -4.1992e-01,  1.9336e-01,  ...,  2.3750e+00,\n",
       "            1.0781e+00,  7.4219e-01],\n",
       "          ...,\n",
       "          [-1.8652e-01, -1.1035e-01, -7.5684e-02,  ...,  2.5156e+00,\n",
       "           -3.5938e-01, -1.1797e+00],\n",
       "          [ 6.6895e-02, -2.6562e-01, -5.4199e-02,  ...,  2.5156e+00,\n",
       "            4.1211e-01, -6.0938e-01],\n",
       "          [ 2.6953e-01,  2.9102e-01, -3.9062e-03,  ...,  2.5938e+00,\n",
       "            8.9844e-01,  4.7656e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 8.9722e-03,  3.4790e-03, -1.5625e-02,  ...,  8.0566e-03,\n",
       "            6.1646e-03, -6.3171e-03],\n",
       "          [ 1.9727e-01, -8.3008e-02, -2.0117e-01,  ..., -9.5312e-01,\n",
       "            1.0312e+00,  6.8054e-03],\n",
       "          [ 4.8584e-02,  4.8584e-02, -1.3770e-01,  ..., -4.3164e-01,\n",
       "            4.2188e-01,  3.3594e-01],\n",
       "          ...,\n",
       "          [-1.2734e+00, -3.8672e-01,  3.4766e-01,  ...,  5.4688e-01,\n",
       "            3.1128e-02, -2.0703e-01],\n",
       "          [-8.7500e-01,  2.0508e-01,  8.5547e-01,  ...,  3.0762e-02,\n",
       "            8.4961e-02, -5.1953e-01],\n",
       "          [-2.3438e-01, -6.7444e-03,  5.7031e-01,  ...,  8.7891e-02,\n",
       "           -7.1094e-01, -1.0693e-01]],\n",
       "\n",
       "         [[-1.0071e-02,  2.9907e-02,  5.8899e-03,  ..., -3.7384e-03,\n",
       "           -2.3956e-03, -1.4832e-02],\n",
       "          [ 1.2061e-01, -4.0430e-01,  1.6602e-01,  ..., -4.1602e-01,\n",
       "            4.6875e-01, -3.4766e-01],\n",
       "          [ 4.5898e-01, -3.8672e-01,  5.3125e-01,  ..., -4.2578e-01,\n",
       "            6.3672e-01, -1.1719e-01],\n",
       "          ...,\n",
       "          [-6.5234e-01,  2.0801e-01,  2.5586e-01,  ...,  8.4766e-01,\n",
       "           -4.1602e-01, -7.1777e-02],\n",
       "          [-2.8442e-02,  3.1641e-01, -6.0730e-03,  ...,  8.6719e-01,\n",
       "           -6.2109e-01,  3.1836e-01],\n",
       "          [ 1.9434e-01,  2.7930e-01,  2.3340e-01,  ..., -2.7734e-01,\n",
       "           -2.0215e-01,  7.8906e-01]],\n",
       "\n",
       "         [[ 1.0803e-02,  1.7334e-02, -1.6235e-02,  ...,  1.3611e-02,\n",
       "            9.3994e-03,  2.7588e-02],\n",
       "          [-7.0801e-02,  4.8828e-01, -2.4805e-01,  ...,  9.5312e-01,\n",
       "            4.3213e-02, -7.4219e-01],\n",
       "          [-1.7480e-01,  5.7031e-01,  4.5312e-01,  ...,  9.4531e-01,\n",
       "           -5.7129e-02, -6.3672e-01],\n",
       "          ...,\n",
       "          [-1.3516e+00,  3.2031e-01,  1.0625e+00,  ...,  1.3906e+00,\n",
       "            1.1523e-01,  1.2793e-01],\n",
       "          [-1.1094e+00,  2.2363e-01,  7.8125e-01,  ...,  8.2422e-01,\n",
       "           -1.2500e-01,  4.4434e-02],\n",
       "          [-9.5703e-01, -5.5859e-01,  1.7812e+00,  ...,  1.9375e+00,\n",
       "            2.1875e-01, -6.6223e-03]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-2.6001e-02, -1.1475e-02, -7.4158e-03,  ..., -1.0559e-02,\n",
       "           -2.0752e-03,  9.6436e-03],\n",
       "          [-8.9062e-01, -3.5156e-02, -1.8555e-01,  ...,  6.3672e-01,\n",
       "           -8.9453e-01,  1.0889e-01],\n",
       "          [-1.1953e+00, -5.2734e-01, -3.9648e-01,  ...,  6.5625e-01,\n",
       "           -3.6328e-01,  2.3340e-01],\n",
       "          ...,\n",
       "          [ 6.8750e-01,  3.7891e-01,  4.8438e-01,  ...,  7.9688e-01,\n",
       "            5.4688e-01,  2.8711e-01],\n",
       "          [ 5.8203e-01,  1.2656e+00,  1.9141e+00,  ...,  2.3193e-02,\n",
       "            2.6245e-02, -5.9766e-01],\n",
       "          [-6.4062e-01, -6.1328e-01, -1.4258e-01,  ...,  6.6895e-02,\n",
       "            8.9355e-02,  1.1816e-01]],\n",
       "\n",
       "         [[-9.2697e-04, -3.5400e-03,  6.2561e-03,  ..., -3.0640e-02,\n",
       "            1.0742e-02, -8.6594e-04],\n",
       "          [ 2.8198e-02, -4.0430e-01, -8.1250e-01,  ...,  2.1973e-01,\n",
       "           -4.8584e-02, -1.0693e-01],\n",
       "          [ 4.7119e-02, -3.0664e-01, -8.6719e-01,  ...,  2.2168e-01,\n",
       "            1.8066e-01,  1.0254e-01],\n",
       "          ...,\n",
       "          [ 2.4512e-01, -3.5352e-01,  5.3906e-01,  ...,  2.5391e-01,\n",
       "           -1.4551e-01,  3.1250e-02],\n",
       "          [ 1.1797e+00, -7.5781e-01,  7.1875e-01,  ..., -2.3340e-01,\n",
       "            4.6094e-01,  1.0547e+00],\n",
       "          [ 1.6968e-02, -1.2422e+00,  2.7148e-01,  ..., -6.6016e-01,\n",
       "            2.0801e-01,  1.7969e-01]],\n",
       "\n",
       "         [[ 8.1177e-03,  7.1716e-03,  1.1353e-02,  ...,  1.1963e-02,\n",
       "            1.0925e-02, -1.2024e-02],\n",
       "          [ 2.0605e-01, -1.3594e+00, -3.9844e-01,  ..., -1.3086e-01,\n",
       "            5.3125e-01, -2.5977e-01],\n",
       "          [ 7.7344e-01, -7.5781e-01, -2.2949e-01,  ..., -7.9297e-01,\n",
       "           -4.1406e-01, -2.4316e-01],\n",
       "          ...,\n",
       "          [ 2.7539e-01, -3.3264e-03, -3.3594e-01,  ..., -3.2031e-01,\n",
       "            6.8359e-03,  5.3125e-01],\n",
       "          [-2.2754e-01, -1.0010e-01, -4.7852e-01,  ..., -5.3516e-01,\n",
       "           -2.6172e-01,  6.5625e-01],\n",
       "          [ 7.5195e-02, -9.6484e-01, -3.3008e-01,  ..., -3.9258e-01,\n",
       "            3.9258e-01, -3.4180e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-3.9673e-03,  6.9580e-03,  1.5198e-02,  ..., -1.8555e-01,\n",
       "           -1.9684e-03, -7.7734e-01],\n",
       "          [-5.0781e-01,  1.2451e-01,  4.3359e-01,  ...,  7.9688e-01,\n",
       "            1.3359e+00,  2.5156e+00],\n",
       "          [ 1.9434e-01,  3.7500e-01,  4.0234e-01,  ..., -1.9336e-01,\n",
       "            1.8047e+00,  2.7031e+00],\n",
       "          ...,\n",
       "          [ 2.3242e-01,  1.0986e-01, -4.3945e-01,  ...,  3.3789e-01,\n",
       "            8.9062e-01,  1.6250e+00],\n",
       "          [-7.8125e-03,  5.2344e-01, -1.1865e-01,  ...,  2.0781e+00,\n",
       "            4.5898e-01,  9.9609e-01],\n",
       "          [-2.6953e-01,  5.3125e-01, -4.6094e-01,  ...,  3.2227e-01,\n",
       "            1.3906e+00,  2.4707e-01]],\n",
       "\n",
       "         [[-1.6479e-02, -2.6978e-02,  1.9775e-02,  ...,  3.1836e-01,\n",
       "            4.3945e-01,  3.3008e-01],\n",
       "          [-5.9375e-01, -1.4453e-01,  2.3047e-01,  ...,  2.9492e-01,\n",
       "            1.0449e-01, -6.1719e-01],\n",
       "          [ 2.3633e-01,  1.6602e-02, -6.5918e-02,  ...,  3.8672e-01,\n",
       "            4.5117e-01, -1.0156e+00],\n",
       "          ...,\n",
       "          [-1.8164e-01, -1.5625e-02,  3.2227e-02,  ...,  6.2891e-01,\n",
       "           -3.7695e-01, -7.6172e-01],\n",
       "          [-1.3984e+00,  4.7266e-01,  5.8984e-01,  ...,  3.3789e-01,\n",
       "           -5.3516e-01,  1.9141e-01],\n",
       "          [-1.3125e+00,  8.6328e-01,  1.2988e-01,  ..., -8.8672e-01,\n",
       "           -8.1641e-01, -2.6953e-01]],\n",
       "\n",
       "         [[-5.4016e-03, -1.9836e-03, -3.9978e-03,  ..., -3.1641e-01,\n",
       "            5.5176e-02,  1.9434e-01],\n",
       "          [-2.4512e-01, -4.9609e-01, -7.7734e-01,  ...,  3.7891e-01,\n",
       "            9.2188e-01, -2.7466e-02],\n",
       "          [ 3.6328e-01, -5.5078e-01,  3.5938e-01,  ...,  2.8906e-01,\n",
       "           -1.4160e-01,  6.3672e-01],\n",
       "          ...,\n",
       "          [-1.4648e-02,  2.6562e-01,  5.4688e-01,  ..., -1.4141e+00,\n",
       "           -1.2734e+00, -8.7109e-01],\n",
       "          [ 5.8203e-01,  2.9297e-02, -5.1953e-01,  ...,  8.7891e-01,\n",
       "            6.1035e-02, -7.6953e-01],\n",
       "          [-4.1504e-02,  5.4297e-01, -8.2031e-01,  ...,  9.2188e-01,\n",
       "           -9.9219e-01, -3.9258e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-3.5400e-02,  2.9907e-03, -4.9591e-04,  ...,  7.6172e-02,\n",
       "            2.6172e-01, -2.8809e-02],\n",
       "          [-1.1719e-02, -5.8203e-01,  8.0859e-01,  ..., -9.1406e-01,\n",
       "           -3.8574e-02, -2.2344e+00],\n",
       "          [-2.9492e-01, -9.7656e-01,  1.0547e+00,  ..., -2.0605e-01,\n",
       "           -3.5938e-01, -2.7969e+00],\n",
       "          ...,\n",
       "          [-6.4844e-01, -2.0898e-01, -1.4160e-01,  ..., -4.0234e-01,\n",
       "            9.4727e-02, -1.2188e+00],\n",
       "          [-4.3945e-01, -1.1250e+00, -2.0312e+00,  ..., -1.4609e+00,\n",
       "            1.5312e+00,  5.9766e-01],\n",
       "          [ 1.2402e-01,  8.6328e-01, -5.2344e-01,  ..., -4.0430e-01,\n",
       "            4.0820e-01, -1.3906e+00]],\n",
       "\n",
       "         [[-2.4658e-02, -9.0332e-03,  3.1128e-02,  ..., -1.9824e-01,\n",
       "            5.1953e-01,  2.8809e-02],\n",
       "          [ 1.0352e-01, -3.9062e-01, -5.3906e-01,  ...,  1.2656e+00,\n",
       "           -5.5078e-01, -2.5781e+00],\n",
       "          [ 3.7891e-01, -6.0938e-01, -6.2109e-01,  ...,  1.4453e+00,\n",
       "           -3.7500e-01, -1.1562e+00],\n",
       "          ...,\n",
       "          [-7.3853e-03,  3.4766e-01, -1.3867e-01,  ...,  7.6562e-01,\n",
       "           -6.6406e-01, -7.2656e-01],\n",
       "          [-1.8555e-02,  3.3594e-01, -6.2012e-02,  ...,  1.4844e+00,\n",
       "           -1.7500e+00,  4.0039e-01],\n",
       "          [-4.2969e-02,  4.0039e-01,  3.4766e-01,  ...,  1.7969e+00,\n",
       "           -1.5625e-01, -9.6875e-01]],\n",
       "\n",
       "         [[ 9.4604e-03, -4.1504e-03,  3.2959e-02,  ...,  3.4766e-01,\n",
       "            1.0469e+00,  4.3164e-01],\n",
       "          [-8.4961e-02, -3.7109e-02, -2.8076e-02,  ..., -1.2891e+00,\n",
       "           -1.9922e+00,  4.3164e-01],\n",
       "          [-1.6797e-01, -1.4771e-02, -1.3477e-01,  ..., -9.0234e-01,\n",
       "           -2.4531e+00, -9.1406e-01],\n",
       "          ...,\n",
       "          [-4.3750e-01,  3.7695e-01, -1.6797e-01,  ..., -3.2031e-01,\n",
       "           -1.8125e+00, -2.2656e+00],\n",
       "          [-1.5430e-01, -3.7305e-01, -1.9824e-01,  ..., -8.0078e-01,\n",
       "            2.9785e-02, -1.6953e+00],\n",
       "          [ 3.7891e-01, -1.5430e-01, -2.0703e-01,  ...,  6.6016e-01,\n",
       "            4.5508e-01, -7.6562e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-7.8125e-03, -4.6692e-03, -4.5013e-04,  ..., -6.9885e-03,\n",
       "           -8.5449e-03, -8.4839e-03],\n",
       "          [-3.3984e-01, -2.0020e-01,  2.7148e-01,  ...,  4.5898e-02,\n",
       "            8.7109e-01,  5.1562e-01],\n",
       "          [-1.8652e-01,  6.1279e-02,  2.6611e-02,  ...,  2.6562e-01,\n",
       "            1.4453e+00,  6.4844e-01],\n",
       "          ...,\n",
       "          [ 1.2305e-01, -6.4453e-01,  3.7305e-01,  ..., -1.1250e+00,\n",
       "            1.9043e-01, -9.5703e-01],\n",
       "          [-6.0547e-01, -4.5898e-01, -5.6641e-01,  ..., -1.5469e+00,\n",
       "           -3.9795e-02, -3.7891e-01],\n",
       "          [-8.2520e-02,  2.5757e-02,  2.6978e-02,  ..., -7.6172e-01,\n",
       "           -1.1035e-01,  1.9165e-02]],\n",
       "\n",
       "         [[ 1.8311e-02,  2.8992e-03,  1.8066e-02,  ..., -5.8594e-03,\n",
       "            4.4250e-03, -1.2894e-03],\n",
       "          [ 7.6953e-01, -8.1250e-01,  8.8867e-02,  ..., -1.0205e-01,\n",
       "           -3.1250e-01, -2.5781e-01],\n",
       "          [ 8.9062e-01, -1.3047e+00, -2.5977e-01,  ...,  2.4121e-01,\n",
       "           -5.2490e-02, -2.3730e-01],\n",
       "          ...,\n",
       "          [-5.1953e-01,  2.0020e-01, -2.2461e-01,  ..., -2.8906e-01,\n",
       "           -1.0986e-01, -5.9128e-04],\n",
       "          [-2.8906e-01,  5.2344e-01,  6.2988e-02,  ..., -4.4922e-01,\n",
       "           -2.9883e-01, -1.0400e-01],\n",
       "          [ 1.0205e-01, -6.4844e-01, -3.3203e-01,  ..., -5.7422e-01,\n",
       "           -4.4922e-01, -5.0781e-01]],\n",
       "\n",
       "         [[ 3.6001e-05, -1.6556e-03, -3.1128e-03,  ...,  2.0630e-02,\n",
       "           -1.0925e-02,  8.6670e-03],\n",
       "          [ 1.5039e-01, -7.6953e-01, -5.0781e-01,  ..., -3.6914e-01,\n",
       "            2.9883e-01,  6.0547e-01],\n",
       "          [ 4.7070e-01, -3.2422e-01, -9.8145e-02,  ...,  5.9766e-01,\n",
       "            4.6289e-01, -1.4453e-01],\n",
       "          ...,\n",
       "          [-6.4844e-01,  5.3516e-01,  9.1016e-01,  ..., -4.0039e-02,\n",
       "           -4.5166e-02, -8.3203e-01],\n",
       "          [-1.7395e-03, -6.2500e-01,  9.8438e-01,  ..., -5.4297e-01,\n",
       "           -3.1836e-01, -3.4570e-01],\n",
       "          [ 3.7500e-01, -1.7578e-01, -7.6172e-01,  ...,  3.3398e-01,\n",
       "            6.0156e-01, -6.4062e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.2451e-02, -6.0120e-03, -1.5640e-03,  ..., -5.1575e-03,\n",
       "            1.5015e-02,  6.8054e-03],\n",
       "          [ 1.3984e+00,  1.5991e-02,  2.5781e-01,  ..., -4.5654e-02,\n",
       "           -7.9590e-02, -1.0156e-01],\n",
       "          [ 1.4219e+00,  7.3438e-01,  1.3867e-01,  ...,  5.0391e-01,\n",
       "            1.6724e-02, -6.9922e-01],\n",
       "          ...,\n",
       "          [ 8.0859e-01,  4.3164e-01, -1.6699e-01,  ...,  9.3750e-01,\n",
       "            1.7773e-01, -1.2344e+00],\n",
       "          [ 8.3203e-01, -4.6875e-02, -5.1562e-01,  ...,  3.3594e-01,\n",
       "            1.5625e-01, -9.6484e-01],\n",
       "          [-5.1953e-01,  6.9580e-03, -1.3984e+00,  ..., -3.9844e-01,\n",
       "           -2.6953e-01, -1.3125e+00]],\n",
       "\n",
       "         [[ 7.0953e-04, -1.5747e-02, -4.2725e-03,  ..., -4.5776e-03,\n",
       "           -3.7689e-03, -2.4414e-03],\n",
       "          [-1.4160e-01, -5.3516e-01, -5.3711e-02,  ...,  5.7031e-01,\n",
       "            8.2812e-01,  4.4922e-01],\n",
       "          [-4.1016e-02, -7.1484e-01,  2.9297e-02,  ...,  1.4648e-01,\n",
       "            5.5859e-01,  1.0254e-01],\n",
       "          ...,\n",
       "          [-2.3730e-01, -2.5391e-01, -5.3906e-01,  ...,  2.8442e-02,\n",
       "           -4.3945e-01, -3.0078e-01],\n",
       "          [-4.9805e-01, -8.0469e-01,  6.3965e-02,  ...,  1.2500e+00,\n",
       "           -1.1816e-01,  1.1426e-01],\n",
       "          [ 2.8711e-01, -4.1406e-01,  4.6875e-02,  ..., -3.5156e-01,\n",
       "           -2.0801e-01,  6.3281e-01]],\n",
       "\n",
       "         [[-8.3008e-03,  1.7700e-02, -1.0742e-02,  ...,  1.6022e-03,\n",
       "           -1.6235e-02, -5.1575e-03],\n",
       "          [ 4.2969e-01,  1.8359e-01, -3.5938e-01,  ...,  6.9531e-01,\n",
       "           -1.5332e-01,  4.0234e-01],\n",
       "          [ 6.9141e-01, -1.8945e-01, -3.8867e-01,  ...,  9.2773e-02,\n",
       "            1.6895e-01,  4.5898e-02],\n",
       "          ...,\n",
       "          [ 4.2383e-01,  1.9434e-01, -3.1055e-01,  ...,  2.1680e-01,\n",
       "            1.3281e-01, -4.7607e-02],\n",
       "          [ 4.3164e-01, -6.9922e-01,  4.7266e-01,  ...,  4.3164e-01,\n",
       "           -3.5547e-01, -2.4780e-02],\n",
       "          [-5.6396e-02,  3.1445e-01,  2.7734e-01,  ...,  1.0193e-02,\n",
       "           -3.1641e-01,  2.2949e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-2.3556e-04,  1.9653e-02,  2.3682e-02,  ...,  1.9653e-02,\n",
       "           -6.5613e-03, -1.2158e-01],\n",
       "          [ 1.3516e+00, -3.3984e-01, -8.9062e-01,  ...,  2.5391e-01,\n",
       "            2.8906e-01,  9.6484e-01],\n",
       "          [ 6.5625e-01, -2.7930e-01, -1.2500e+00,  ...,  1.1182e-01,\n",
       "            3.7109e-01,  7.7344e-01],\n",
       "          ...,\n",
       "          [ 6.7188e-01, -3.7109e-01,  1.3906e+00,  ..., -4.3945e-01,\n",
       "           -4.3555e-01, -3.9062e-02],\n",
       "          [ 7.2656e-01, -1.0625e+00,  1.4609e+00,  ...,  1.0312e+00,\n",
       "           -7.1094e-01, -2.7539e-01],\n",
       "          [ 8.1641e-01,  1.2451e-01, -1.6211e-01,  ..., -1.1094e+00,\n",
       "            4.0625e-01, -1.1016e+00]],\n",
       "\n",
       "         [[ 2.7771e-03, -8.0566e-03,  2.1851e-02,  ..., -2.2461e-01,\n",
       "           -2.0605e-01, -6.5234e-01],\n",
       "          [-9.6875e-01, -1.0938e+00,  4.8828e-02,  ...,  1.1406e+00,\n",
       "           -1.7031e+00,  1.5391e+00],\n",
       "          [ 7.1484e-01, -1.0156e+00,  1.5259e-02,  ...,  9.7266e-01,\n",
       "           -9.3750e-01,  1.6562e+00],\n",
       "          ...,\n",
       "          [ 5.6250e-01, -3.2422e-01, -9.3750e-01,  ...,  1.0312e+00,\n",
       "           -5.8594e-01, -4.0820e-01],\n",
       "          [-3.5938e-01, -1.7285e-01, -2.6758e-01,  ...,  7.2656e-01,\n",
       "           -7.3438e-01, -9.3750e-01],\n",
       "          [-2.8711e-01,  5.3516e-01, -3.4375e-01,  ...,  9.8828e-01,\n",
       "           -1.0938e+00,  2.9053e-02]],\n",
       "\n",
       "         [[-1.1597e-02,  1.5488e-03,  3.5156e-02,  ..., -3.6328e-01,\n",
       "            4.4678e-02,  1.6113e-01],\n",
       "          [-2.6172e-01,  8.5938e-02, -2.2559e-01,  ..., -8.3203e-01,\n",
       "           -1.8203e+00, -4.2578e-01],\n",
       "          [-4.5508e-01,  9.0234e-01,  6.3477e-02,  ...,  3.9844e-01,\n",
       "           -2.0625e+00, -2.5391e-01],\n",
       "          ...,\n",
       "          [ 5.7031e-01,  9.8633e-02,  2.3828e-01,  ...,  1.7812e+00,\n",
       "           -1.2578e+00, -3.5547e-01],\n",
       "          [ 5.2734e-01,  8.8672e-01,  3.0859e-01,  ...,  1.2031e+00,\n",
       "           -1.5234e-01, -8.3203e-01],\n",
       "          [-4.2188e-01, -5.3516e-01, -1.0547e-01,  ...,  1.8125e+00,\n",
       "           -4.2773e-01,  3.6133e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.8921e-02,  2.1240e-02,  2.1362e-02,  ...,  1.9531e+00,\n",
       "            2.2461e-01, -1.0625e+00],\n",
       "          [-1.4551e-01, -2.9883e-01, -2.0703e-01,  ..., -3.8125e+00,\n",
       "            2.1094e+00,  1.7812e+00],\n",
       "          [-5.8984e-01, -6.7188e-01, -9.2188e-01,  ..., -4.1562e+00,\n",
       "            2.1719e+00,  1.3125e+00],\n",
       "          ...,\n",
       "          [ 4.6680e-01,  4.5703e-01,  3.8672e-01,  ..., -3.9688e+00,\n",
       "            7.9688e-01, -1.1094e+00],\n",
       "          [ 2.3340e-01, -1.6797e-01,  2.2559e-01,  ..., -6.8750e-01,\n",
       "            2.7500e+00,  8.3984e-01],\n",
       "          [ 2.7734e-01,  5.7031e-01,  5.0781e-01,  ..., -5.3750e+00,\n",
       "            2.2812e+00,  5.2246e-02]],\n",
       "\n",
       "         [[-2.7100e-02, -7.4158e-03, -2.2461e-02,  ..., -6.9141e-01,\n",
       "            9.0820e-02,  4.0039e-01],\n",
       "          [-1.3281e-01, -4.4531e-01,  2.8711e-01,  ...,  1.6016e+00,\n",
       "           -1.2734e+00, -8.6328e-01],\n",
       "          [-4.7266e-01, -3.1641e-01,  6.2109e-01,  ...,  2.1562e+00,\n",
       "           -3.7891e-01, -3.4766e-01],\n",
       "          ...,\n",
       "          [ 3.2617e-01, -8.7891e-01, -9.0234e-01,  ..., -3.9648e-01,\n",
       "           -2.0469e+00, -1.1797e+00],\n",
       "          [-2.3438e-01, -1.1719e-01, -1.8457e-01,  ..., -2.9883e-01,\n",
       "           -2.3438e+00, -7.3828e-01],\n",
       "          [-1.6406e-01, -3.9062e-02,  4.4922e-02,  ..., -6.9141e-01,\n",
       "           -7.2656e-01,  1.2578e+00]],\n",
       "\n",
       "         [[-2.4292e-02,  4.6997e-03, -7.7820e-03,  ...,  1.4160e-01,\n",
       "           -8.9844e-02,  7.4219e-02],\n",
       "          [-4.6484e-01,  4.0234e-01, -1.1406e+00,  ...,  4.6484e-01,\n",
       "            8.9844e-02,  1.5000e+00],\n",
       "          [ 6.4453e-02, -4.8242e-01, -8.5938e-01,  ...,  3.8477e-01,\n",
       "            5.5469e-01,  1.6875e+00],\n",
       "          ...,\n",
       "          [-3.5156e-01, -9.7656e-01, -1.1094e+00,  ...,  2.6953e-01,\n",
       "           -1.7344e+00,  1.5781e+00],\n",
       "          [-1.6641e+00, -6.8750e-01, -3.4375e-01,  ..., -9.2578e-01,\n",
       "           -1.0391e+00,  1.1875e+00],\n",
       "          [-8.8672e-01,  5.7422e-01,  1.1953e+00,  ...,  2.7930e-01,\n",
       "           -2.4902e-01,  4.9609e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-5.2185e-03,  2.2583e-02, -1.8692e-03,  ..., -1.6785e-03,\n",
       "           -9.8419e-04,  9.5825e-03],\n",
       "          [-5.4688e-01, -1.6846e-02,  6.8359e-02,  ..., -6.2500e-01,\n",
       "            1.4941e-01,  3.6523e-01],\n",
       "          [-6.5234e-01, -1.4258e-01,  5.9570e-02,  ..., -3.7891e-01,\n",
       "            2.9688e-01,  2.5195e-01],\n",
       "          ...,\n",
       "          [-1.8066e-01,  2.1484e-01, -4.1797e-01,  ..., -9.0234e-01,\n",
       "           -5.5859e-01,  2.8516e-01],\n",
       "          [-1.3379e-01, -1.2012e-01, -3.7500e-01,  ..., -8.6328e-01,\n",
       "           -9.8828e-01, -3.0078e-01],\n",
       "          [ 3.7305e-01, -2.1875e-01,  4.4678e-02,  ..., -1.2988e-01,\n",
       "           -5.8594e-01, -5.3125e-01]],\n",
       "\n",
       "         [[-2.7832e-02, -6.5613e-03, -5.6458e-03,  ...,  1.9897e-02,\n",
       "            9.7275e-04, -1.3580e-03],\n",
       "          [-2.5391e-01, -4.5898e-01, -2.1777e-01,  ..., -8.2422e-01,\n",
       "            1.8848e-01,  2.2754e-01],\n",
       "          [-4.0527e-02, -8.5938e-01, -2.0605e-01,  ..., -7.2266e-01,\n",
       "            2.9883e-01, -9.4727e-02],\n",
       "          ...,\n",
       "          [-2.4512e-01, -2.6367e-01,  2.9297e-01,  ...,  2.3633e-01,\n",
       "            1.0645e-01, -1.0010e-01],\n",
       "          [ 4.4531e-01,  2.1582e-01,  2.5977e-01,  ...,  4.1992e-01,\n",
       "            1.6211e-01, -4.5898e-01],\n",
       "          [-4.4922e-01, -2.7148e-01,  3.5547e-01,  ...,  1.6113e-01,\n",
       "            3.3789e-01,  8.5938e-01]],\n",
       "\n",
       "         [[-2.7008e-03,  1.0071e-02, -1.3580e-03,  ..., -1.4877e-04,\n",
       "           -2.0905e-03,  1.1597e-02],\n",
       "          [-8.7891e-03, -1.1328e+00, -3.2617e-01,  ..., -3.8086e-01,\n",
       "            5.6250e-01,  1.6113e-01],\n",
       "          [-6.8359e-02, -1.0312e+00, -3.8086e-01,  ...,  2.6562e-01,\n",
       "            4.9023e-01,  9.0820e-02],\n",
       "          ...,\n",
       "          [-3.2715e-02,  2.1094e-01, -1.1914e-01,  ...,  5.5859e-01,\n",
       "           -9.1406e-01,  3.9648e-01],\n",
       "          [ 6.3672e-01,  5.5859e-01, -1.5234e-01,  ..., -1.2329e-02,\n",
       "            4.8438e-01, -1.5039e-01],\n",
       "          [-4.8340e-02,  4.1992e-01, -5.3125e-01,  ..., -7.1777e-02,\n",
       "            1.6016e-01,  4.0625e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.2085e-02,  9.3994e-03, -8.6670e-03,  ..., -2.5146e-02,\n",
       "           -1.0071e-03,  5.2185e-03],\n",
       "          [ 3.2227e-01, -5.3516e-01, -5.6641e-01,  ...,  8.7109e-01,\n",
       "            5.5078e-01, -2.2852e-01],\n",
       "          [ 1.2598e-01,  3.2031e-01, -4.5410e-02,  ...,  8.6328e-01,\n",
       "            3.8672e-01,  2.4121e-01],\n",
       "          ...,\n",
       "          [ 4.1504e-02,  3.2227e-01,  1.1768e-01,  ...,  8.6719e-01,\n",
       "           -6.0547e-01, -1.7578e-02],\n",
       "          [ 1.2256e-01,  9.6484e-01, -7.8906e-01,  ...,  6.8359e-01,\n",
       "           -1.0889e-01,  6.4453e-01],\n",
       "          [ 3.5352e-01, -4.8828e-02, -2.4414e-01,  ...,  5.9766e-01,\n",
       "            3.0078e-01,  3.7305e-01]],\n",
       "\n",
       "         [[-1.4832e-02, -2.4109e-03,  3.0518e-03,  ..., -1.5717e-03,\n",
       "            6.5918e-03,  1.8433e-02],\n",
       "          [ 2.2363e-01, -1.2812e+00, -2.5586e-01,  ..., -4.8438e-01,\n",
       "            6.8750e-01, -5.2734e-01],\n",
       "          [ 3.7500e-01, -6.0156e-01,  3.1738e-02,  ..., -2.0508e-01,\n",
       "            6.8750e-01, -1.2207e-01],\n",
       "          ...,\n",
       "          [ 5.3906e-01,  5.0000e-01,  3.5547e-01,  ...,  4.4189e-02,\n",
       "           -5.8984e-01,  2.6562e-01],\n",
       "          [-1.5527e-01, -5.0781e-01, -7.4707e-02,  ..., -4.4531e-01,\n",
       "            9.1797e-02,  1.4355e-01],\n",
       "          [-1.9629e-01, -1.4160e-01, -1.0859e+00,  ..., -2.9883e-01,\n",
       "            4.6680e-01, -3.4570e-01]],\n",
       "\n",
       "         [[ 1.0437e-02, -2.5513e-02, -7.5378e-03,  ...,  6.6223e-03,\n",
       "           -1.9897e-02, -3.8605e-03],\n",
       "          [-2.1191e-01, -2.9297e-01,  4.7852e-01,  ...,  3.0664e-01,\n",
       "            8.9062e-01,  2.9492e-01],\n",
       "          [ 2.1875e-01, -1.5503e-02,  8.2812e-01,  ...,  8.6719e-01,\n",
       "            8.9453e-01, -3.3112e-03],\n",
       "          ...,\n",
       "          [ 1.0703e+00, -8.3008e-02, -5.8594e-01,  ..., -5.5078e-01,\n",
       "            5.5859e-01,  5.1953e-01],\n",
       "          [ 4.6875e-01, -5.6250e-01,  1.5625e+00,  ..., -7.0703e-01,\n",
       "            2.6367e-01,  2.0996e-01],\n",
       "          [ 8.2031e-01,  3.4961e-01,  7.1484e-01,  ...,  3.6914e-01,\n",
       "            6.4844e-01, -1.0391e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 3.3203e-02,  2.2430e-03, -5.4016e-03,  ...,  1.3672e-01,\n",
       "           -2.4902e-01, -1.2656e+00],\n",
       "          [-2.7734e-01,  5.9375e-01, -1.8457e-01,  ...,  7.3242e-02,\n",
       "            3.2969e+00,  8.0000e+00],\n",
       "          [-8.0859e-01,  3.9062e-01, -2.7344e-01,  ...,  7.7344e-01,\n",
       "            3.1562e+00,  7.5625e+00],\n",
       "          ...,\n",
       "          [ 8.5547e-01, -3.1641e-01,  3.8477e-01,  ...,  3.0469e-01,\n",
       "            3.9844e+00,  5.7500e+00],\n",
       "          [ 4.8828e-01, -1.2305e-01,  2.8125e-01,  ..., -7.6953e-01,\n",
       "            2.9688e+00,  4.5625e+00],\n",
       "          [ 2.7344e-01,  5.9570e-02, -1.0742e-01,  ..., -1.3984e+00,\n",
       "            2.9688e+00,  4.4688e+00]],\n",
       "\n",
       "         [[-2.6001e-02,  3.4424e-02,  6.8665e-03,  ..., -2.1582e-01,\n",
       "            4.3213e-02,  3.2617e-01],\n",
       "          [-1.8359e-01, -2.8711e-01,  6.7383e-02,  ...,  1.7676e-01,\n",
       "            4.6875e-01,  2.5781e+00],\n",
       "          [-6.8359e-01,  5.9766e-01,  1.8164e-01,  ...,  2.2656e-01,\n",
       "            2.3145e-01,  1.7422e+00],\n",
       "          ...,\n",
       "          [-6.5625e-01,  3.4961e-01,  1.2695e-01,  ...,  1.2500e+00,\n",
       "            2.8711e-01, -2.5781e-01],\n",
       "          [ 3.0396e-02, -4.8242e-01, -4.8828e-03,  ...,  2.3906e+00,\n",
       "           -2.0215e-01, -2.2266e-01],\n",
       "          [ 9.5312e-01, -2.4023e-01, -4.2480e-02,  ..., -1.8359e-01,\n",
       "            9.9609e-01,  1.6484e+00]],\n",
       "\n",
       "         [[ 7.2327e-03, -1.0529e-03,  1.2207e-02,  ..., -9.0790e-04,\n",
       "           -5.2490e-02, -8.6914e-02],\n",
       "          [ 3.6523e-01, -1.5859e+00,  1.6895e-01,  ..., -8.4229e-03,\n",
       "            1.4648e-01,  1.1953e+00],\n",
       "          [-1.2012e-01, -1.0312e+00,  6.7383e-02,  ..., -3.3984e-01,\n",
       "            2.6172e-01,  2.0625e+00],\n",
       "          ...,\n",
       "          [ 1.4609e+00, -2.3438e-01,  3.0469e-01,  ...,  4.7852e-01,\n",
       "            1.3359e+00,  9.6484e-01],\n",
       "          [ 4.2188e-01, -4.4531e-01, -1.0107e-01,  ...,  2.0625e+00,\n",
       "            1.0469e+00,  3.3398e-01],\n",
       "          [-8.9355e-02, -7.8125e-01, -8.2031e-01,  ...,  3.3594e-01,\n",
       "            9.8828e-01,  8.7891e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.8076e-02, -9.4223e-04,  1.1047e-02,  ..., -8.0078e-02,\n",
       "            1.3086e-01,  9.6191e-02],\n",
       "          [-6.6406e-01, -1.5000e+00,  3.2031e-01,  ..., -1.1250e+00,\n",
       "           -1.1797e+00,  6.7578e-01],\n",
       "          [ 4.5898e-01, -1.4297e+00,  9.6680e-02,  ..., -3.6914e-01,\n",
       "           -1.0625e+00,  9.0625e-01],\n",
       "          ...,\n",
       "          [ 5.4688e-01,  1.0625e+00, -8.0078e-01,  ..., -1.3047e+00,\n",
       "            1.5527e-01,  1.7285e-01],\n",
       "          [ 6.2500e-01,  5.1172e-01, -5.8203e-01,  ..., -7.5781e-01,\n",
       "            1.0547e+00, -6.7188e-01],\n",
       "          [-3.7109e-01, -6.8359e-01,  2.8906e-01,  ..., -1.9375e+00,\n",
       "           -2.0312e-01, -4.7852e-02]],\n",
       "\n",
       "         [[-5.8899e-03,  8.7891e-03, -1.0834e-03,  ...,  1.6309e-01,\n",
       "           -9.0332e-02, -1.0840e-01],\n",
       "          [-1.4531e+00,  9.8633e-02,  1.2354e-01,  ..., -1.4375e+00,\n",
       "            1.0938e+00, -2.4707e-01],\n",
       "          [-1.2891e+00, -4.6143e-02,  2.2754e-01,  ..., -1.4375e+00,\n",
       "            1.5859e+00,  5.2734e-01],\n",
       "          ...,\n",
       "          [ 1.1953e+00, -5.5078e-01,  1.3281e-01,  ..., -2.8711e-01,\n",
       "            6.0547e-01,  1.6406e+00],\n",
       "          [ 7.1484e-01, -1.2109e-01,  1.1641e+00,  ...,  2.2656e+00,\n",
       "           -1.0059e-01,  7.8613e-02],\n",
       "          [-1.0312e+00,  6.6406e-01,  8.1250e-01,  ...,  3.1445e-01,\n",
       "           -2.7222e-02,  1.1406e+00]],\n",
       "\n",
       "         [[ 1.9165e-02,  3.7354e-02, -2.7466e-02,  ..., -1.2024e-02,\n",
       "           -2.8711e-01,  2.0215e-01],\n",
       "          [ 3.9844e-01,  1.1328e+00,  2.7734e-01,  ...,  1.0254e-02,\n",
       "            1.5938e+00, -1.5469e+00],\n",
       "          [-2.3633e-01, -2.1289e-01, -2.6172e-01,  ...,  7.5781e-01,\n",
       "            1.1523e-01, -6.1328e-01],\n",
       "          ...,\n",
       "          [ 1.0889e-01, -1.9922e-01, -3.6719e-01,  ..., -4.0234e-01,\n",
       "           -5.1172e-01, -1.5000e+00],\n",
       "          [ 1.0078e+00, -4.7266e-01, -6.0938e-01,  ...,  8.9062e-01,\n",
       "           -5.7422e-01, -1.6641e+00],\n",
       "          [ 1.9434e-01,  7.4609e-01,  7.4219e-01,  ...,  1.2969e+00,\n",
       "           -9.2578e-01, -2.2969e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-1.0620e-02,  2.0752e-02,  5.3711e-03,  ..., -2.7100e-02,\n",
       "            9.7046e-03, -1.3733e-02],\n",
       "          [-4.2969e-02,  5.3516e-01,  2.0898e-01,  ..., -4.1797e-01,\n",
       "           -1.0547e+00,  2.5195e-01],\n",
       "          [ 4.7852e-01, -2.3633e-01, -2.5195e-01,  ..., -9.4238e-02,\n",
       "           -1.7383e-01,  3.6328e-01],\n",
       "          ...,\n",
       "          [-3.8477e-01, -4.1992e-01,  5.1953e-01,  ...,  1.1953e+00,\n",
       "           -5.9375e-01, -1.2146e-02],\n",
       "          [ 3.8477e-01,  5.0391e-01,  9.5312e-01,  ...,  4.1992e-01,\n",
       "           -1.0234e+00, -5.3516e-01],\n",
       "          [ 3.2422e-01, -2.1777e-01,  1.5625e-01,  ...,  6.0938e-01,\n",
       "           -4.7266e-01,  5.6641e-01]],\n",
       "\n",
       "         [[-1.6724e-02,  2.5787e-03,  1.7452e-04,  ...,  2.0874e-02,\n",
       "           -6.8359e-03,  7.7515e-03],\n",
       "          [ 4.7266e-01,  3.5742e-01, -1.7480e-01,  ...,  2.2754e-01,\n",
       "           -6.5234e-01, -4.2578e-01],\n",
       "          [ 5.5469e-01,  9.8047e-01, -2.8125e-01,  ..., -1.3281e-01,\n",
       "           -4.4141e-01, -1.0000e+00],\n",
       "          ...,\n",
       "          [ 1.6211e-01, -1.6992e-01,  1.6406e-01,  ..., -1.1875e+00,\n",
       "            6.6797e-01,  3.3398e-01],\n",
       "          [ 4.5898e-01, -5.3125e-01,  2.9688e-01,  ..., -9.9121e-02,\n",
       "           -2.7539e-01, -4.4141e-01],\n",
       "          [ 4.3945e-01, -1.7285e-01, -2.6172e-01,  ..., -6.7383e-02,\n",
       "            2.3633e-01,  4.9219e-01]],\n",
       "\n",
       "         [[ 2.0447e-03,  4.1199e-03,  5.5237e-03,  ...,  2.1851e-02,\n",
       "           -8.8501e-03,  1.2024e-02],\n",
       "          [ 1.6016e-01,  6.8750e-01,  3.4961e-01,  ..., -1.3379e-01,\n",
       "           -1.6211e-01,  8.7891e-01],\n",
       "          [ 6.0547e-01,  2.1484e-01,  4.8633e-01,  ..., -1.3672e-01,\n",
       "           -1.7285e-01, -1.5991e-02],\n",
       "          ...,\n",
       "          [ 1.3184e-01, -4.2188e-01, -4.9219e-01,  ..., -2.6953e-01,\n",
       "           -9.6680e-02, -1.1328e-01],\n",
       "          [ 4.4922e-01, -5.8984e-01,  6.3281e-01,  ...,  3.0859e-01,\n",
       "            6.2500e-01,  4.5654e-02],\n",
       "          [ 3.7891e-01, -5.9766e-01, -5.3125e-01,  ..., -7.1875e-01,\n",
       "            4.2578e-01, -9.2578e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 3.5889e-02,  1.7700e-02,  1.2451e-02,  ...,  1.8921e-02,\n",
       "           -1.2329e-02,  2.6367e-02],\n",
       "          [-2.4536e-02, -2.7344e-02,  2.9102e-01,  ...,  3.4912e-02,\n",
       "           -5.4297e-01, -1.8066e-02],\n",
       "          [ 5.1172e-01, -6.7969e-01, -1.4648e-01,  ...,  3.4180e-01,\n",
       "           -5.6641e-01,  1.8457e-01],\n",
       "          ...,\n",
       "          [ 5.4297e-01,  4.2578e-01, -6.4941e-02,  ...,  1.0391e+00,\n",
       "           -8.6719e-01,  4.1406e-01],\n",
       "          [-5.1953e-01, -4.9072e-02, -1.2422e+00,  ...,  1.2012e-01,\n",
       "           -1.3125e+00,  1.4219e+00],\n",
       "          [ 1.6309e-01,  2.2095e-02, -4.1406e-01,  ...,  4.4531e-01,\n",
       "           -1.0449e-01,  6.2891e-01]],\n",
       "\n",
       "         [[-3.5400e-02,  2.1240e-02, -1.9897e-02,  ..., -1.4099e-02,\n",
       "           -3.7354e-02,  1.5259e-02],\n",
       "          [-5.8203e-01,  1.6406e-01, -8.8281e-01,  ..., -5.7812e-01,\n",
       "            2.1094e-01, -6.6797e-01],\n",
       "          [-6.2109e-01,  4.0430e-01, -1.3516e+00,  ..., -5.0391e-01,\n",
       "            1.0010e-01, -1.0625e+00],\n",
       "          ...,\n",
       "          [ 2.6562e-01,  1.4844e-01,  2.7148e-01,  ...,  3.8672e-01,\n",
       "            6.9531e-01,  3.3789e-01],\n",
       "          [-3.0664e-01, -3.9844e-01,  1.1816e-01,  ...,  7.2656e-01,\n",
       "            1.3574e-01,  3.2617e-01],\n",
       "          [-1.0889e-01,  3.5938e-01,  4.1260e-02,  ...,  1.0596e-01,\n",
       "            5.3516e-01,  2.5586e-01]],\n",
       "\n",
       "         [[-7.8735e-03, -1.3046e-03, -1.8799e-02,  ...,  1.9226e-03,\n",
       "           -5.3711e-03, -1.1719e-02],\n",
       "          [-5.1758e-02,  7.6172e-01,  6.0938e-01,  ...,  5.7031e-01,\n",
       "           -1.6504e-01, -4.4531e-01],\n",
       "          [ 2.5781e-01, -3.3984e-01,  9.2578e-01,  ...,  1.4688e+00,\n",
       "            8.2031e-02, -2.7539e-01],\n",
       "          ...,\n",
       "          [ 1.0078e+00, -4.8828e-01,  5.0000e-01,  ...,  8.3984e-01,\n",
       "           -2.3438e-01, -9.1797e-02],\n",
       "          [ 1.0078e+00, -1.2656e+00,  8.1250e-01,  ...,  9.5703e-01,\n",
       "           -1.3984e+00, -3.0664e-01],\n",
       "          [ 4.7461e-01,  1.6797e+00,  3.8477e-01,  ...,  7.1777e-02,\n",
       "           -7.4609e-01, -1.5547e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-4.8523e-03, -6.0547e-02,  3.6163e-03,  ..., -6.6895e-02,\n",
       "            1.9043e-01, -1.1816e-01],\n",
       "          [ 1.3086e-01,  2.3633e-01,  7.9102e-02,  ..., -3.5742e-01,\n",
       "           -1.5527e-01, -1.2012e-01],\n",
       "          [-8.3984e-01,  9.7656e-01,  2.9492e-01,  ..., -1.1133e-01,\n",
       "            5.6641e-01, -2.1484e-01],\n",
       "          ...,\n",
       "          [ 1.4453e-01, -4.7070e-01, -4.3555e-01,  ..., -1.0693e-01,\n",
       "           -3.2422e-01,  2.2188e+00],\n",
       "          [-5.1758e-02,  2.2188e+00, -1.4844e+00,  ...,  1.1182e-01,\n",
       "            5.0391e-01,  5.3516e-01],\n",
       "          [-4.6484e-01,  6.4844e-01, -1.0938e+00,  ...,  3.3984e-01,\n",
       "            2.8320e-01,  3.3008e-01]],\n",
       "\n",
       "         [[-1.0729e-04, -2.3682e-02, -1.6113e-02,  ..., -1.7871e-01,\n",
       "            7.4707e-02, -2.4707e-01],\n",
       "          [ 9.9609e-02, -2.5781e-01, -4.1016e-02,  ..., -2.7539e-01,\n",
       "            3.3398e-01, -6.6833e-03],\n",
       "          [-8.3984e-02, -2.0410e-01, -7.4609e-01,  ...,  5.4688e-01,\n",
       "            8.0078e-01, -3.0273e-01],\n",
       "          ...,\n",
       "          [-7.0312e-01, -3.2812e-01, -4.2383e-01,  ...,  4.1211e-01,\n",
       "            8.5449e-02,  8.4766e-01],\n",
       "          [-9.5312e-01, -4.7266e-01,  2.9297e-03,  ...,  1.7188e-01,\n",
       "           -1.3672e+00,  4.7070e-01],\n",
       "          [-6.8750e-01, -1.5137e-01,  7.8516e-01,  ...,  1.1797e+00,\n",
       "            8.0078e-01,  1.8672e+00]],\n",
       "\n",
       "         [[-2.0294e-03,  1.8433e-02, -1.1414e-02,  ...,  7.2754e-02,\n",
       "            2.3730e-01, -8.1055e-02],\n",
       "          [ 4.8828e-01, -2.9297e-03,  3.1836e-01,  ...,  2.6406e+00,\n",
       "            7.8125e-01,  3.8672e-01],\n",
       "          [ 2.8516e-01,  7.9297e-01,  1.8652e-01,  ...,  1.8281e+00,\n",
       "            4.9072e-02,  1.6484e+00],\n",
       "          ...,\n",
       "          [-1.7578e+00, -5.7812e-01,  5.0781e-02,  ...,  1.0986e-01,\n",
       "            6.8359e-01, -1.8359e-01],\n",
       "          [-1.8555e-01,  1.5625e-01,  6.3965e-02,  ...,  7.0703e-01,\n",
       "            3.9844e+00, -1.3594e+00],\n",
       "          [-7.1484e-01, -2.5391e-01, -2.3926e-02,  ...,  6.0938e-01,\n",
       "            2.0000e+00, -1.4258e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.9775e-02, -2.0996e-02, -1.0559e-02,  ...,  1.3574e-01,\n",
       "            1.3086e-01, -6.2500e-01],\n",
       "          [-3.8281e-01,  1.6016e-01,  4.4141e-01,  ..., -1.2344e+00,\n",
       "           -2.0625e+00,  4.1016e-01],\n",
       "          [ 1.8848e-01,  2.8516e-01,  4.1992e-01,  ..., -1.0938e+00,\n",
       "           -2.2344e+00,  1.3516e+00],\n",
       "          ...,\n",
       "          [ 3.8477e-01, -2.5586e-01, -3.4180e-01,  ..., -8.6328e-01,\n",
       "           -4.2578e-01,  6.9531e-01],\n",
       "          [-1.3281e-01, -1.5625e-01,  1.6895e-01,  ..., -1.7578e+00,\n",
       "            5.2734e-01,  7.5391e-01],\n",
       "          [ 2.1680e-01, -2.1582e-01, -1.2500e-01,  ..., -1.1797e+00,\n",
       "           -7.8516e-01,  5.3125e-01]],\n",
       "\n",
       "         [[-2.0630e-02,  3.0029e-02, -3.5889e-02,  ...,  2.5938e+00,\n",
       "           -3.0469e-01, -4.8242e-01],\n",
       "          [ 1.6016e-01, -1.6797e+00,  1.5469e+00,  ..., -8.3125e+00,\n",
       "            9.5703e-01, -1.6328e+00],\n",
       "          [ 2.3438e+00, -1.7578e+00,  4.3359e-01,  ..., -9.8125e+00,\n",
       "           -8.0859e-01, -2.6562e+00],\n",
       "          ...,\n",
       "          [-1.1797e+00,  1.0312e+00, -1.0625e+00,  ..., -1.0250e+01,\n",
       "           -9.8047e-01, -3.8906e+00],\n",
       "          [-7.2656e-01,  1.7578e+00,  5.3906e-01,  ..., -9.5625e+00,\n",
       "            1.4062e+00, -3.0938e+00],\n",
       "          [-4.8828e-01,  7.8906e-01, -4.6094e-01,  ..., -9.5000e+00,\n",
       "           -4.1016e-01, -5.0000e-01]],\n",
       "\n",
       "         [[-2.9144e-03, -1.6235e-02,  3.0975e-03,  ..., -1.9141e+00,\n",
       "           -2.0386e-02,  4.7852e-01],\n",
       "          [-3.3398e-01, -1.9238e-01,  1.1914e-01,  ...,  9.3125e+00,\n",
       "           -5.3125e-01,  1.4844e-01],\n",
       "          [-1.3672e-02,  5.7031e-01,  3.9648e-01,  ...,  1.0250e+01,\n",
       "           -5.1562e-01, -6.3672e-01],\n",
       "          ...,\n",
       "          [ 3.7109e-01,  1.4297e+00,  7.1094e-01,  ...,  1.0312e+01,\n",
       "            1.0596e-01, -2.0156e+00],\n",
       "          [-8.6914e-02, -4.2969e-01,  2.9297e-01,  ...,  1.0938e+01,\n",
       "           -1.1816e-01, -1.1641e+00],\n",
       "          [ 5.6641e-02,  3.4180e-02, -4.9219e-01,  ...,  1.0500e+01,\n",
       "           -1.3516e+00, -2.1875e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-2.0874e-02, -8.1787e-03,  3.3203e-02,  ..., -3.9795e-02,\n",
       "           -1.3245e-02,  3.2471e-02],\n",
       "          [-2.2656e-01,  3.6133e-02,  4.2578e-01,  ...,  4.3945e-01,\n",
       "            2.9102e-01,  2.4219e-01],\n",
       "          [-5.1562e-01,  1.3672e-01, -4.3164e-01,  ...,  1.0498e-01,\n",
       "           -4.0625e-01,  5.8203e-01],\n",
       "          ...,\n",
       "          [-2.8516e-01,  7.2754e-02, -8.1641e-01,  ...,  6.4941e-02,\n",
       "            1.0469e+00, -9.4922e-01],\n",
       "          [ 3.3398e-01,  1.8457e-01, -7.8516e-01,  ...,  3.0859e-01,\n",
       "           -3.7109e-01, -8.2031e-01],\n",
       "          [ 1.3359e+00, -1.3184e-01, -9.1406e-01,  ..., -3.4570e-01,\n",
       "           -5.9375e-01, -1.8066e-01]],\n",
       "\n",
       "         [[-1.4404e-02,  1.5625e-02, -1.7944e-02,  ..., -1.0925e-02,\n",
       "           -4.1199e-03,  1.6602e-02],\n",
       "          [ 5.7812e-01, -8.2812e-01, -4.9805e-01,  ...,  6.9922e-01,\n",
       "           -6.1719e-01, -6.3281e-01],\n",
       "          [ 1.9531e-01, -1.0547e+00, -5.4688e-01,  ...,  5.6641e-01,\n",
       "           -8.5938e-01, -4.1602e-01],\n",
       "          ...,\n",
       "          [-1.1641e+00,  3.2471e-02,  1.0449e-01,  ...,  2.5391e-01,\n",
       "           -4.9805e-01,  9.8633e-02],\n",
       "          [-6.9922e-01, -2.6562e-01, -1.3047e+00,  ..., -7.9688e-01,\n",
       "           -1.2812e+00,  3.4570e-01],\n",
       "          [-1.0010e-02,  2.8320e-01, -6.2988e-02,  ...,  1.6211e-01,\n",
       "            2.7539e-01,  5.1562e-01]],\n",
       "\n",
       "         [[ 1.4709e-02,  1.8433e-02, -2.8931e-02,  ...,  1.2878e-02,\n",
       "           -5.0964e-03,  1.5030e-03],\n",
       "          [-8.6426e-02,  3.8574e-02, -1.5000e+00,  ..., -4.6875e-01,\n",
       "            2.7734e-01, -6.0547e-01],\n",
       "          [-1.1816e-01,  2.5977e-01, -1.2266e+00,  ..., -5.1562e-01,\n",
       "            5.3906e-01, -2.9688e-01],\n",
       "          ...,\n",
       "          [-6.6016e-01,  2.5586e-01,  1.3867e-01,  ..., -6.8359e-01,\n",
       "           -9.0625e-01,  5.8984e-01],\n",
       "          [ 1.8125e+00, -1.8750e+00, -2.7500e+00,  ...,  6.4453e-01,\n",
       "            1.0312e+00,  8.7891e-01],\n",
       "          [ 8.3594e-01, -3.5352e-01, -5.1562e-01,  ..., -2.9102e-01,\n",
       "            5.2344e-01, -5.1270e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 2.6367e-02,  1.7643e-04, -4.6692e-03,  ...,  4.0283e-02,\n",
       "           -2.6855e-03,  6.9275e-03],\n",
       "          [ 2.5391e-01, -5.6250e-01, -5.1562e-01,  ...,  5.3125e-01,\n",
       "            1.4893e-02,  2.6367e-01],\n",
       "          [ 1.6895e-01, -1.1719e-01, -5.5078e-01,  ...,  8.4375e-01,\n",
       "           -3.1836e-01,  1.1797e+00],\n",
       "          ...,\n",
       "          [ 4.8438e-01,  2.1094e-01,  3.4668e-02,  ..., -1.1250e+00,\n",
       "            5.0781e-01, -1.4160e-01],\n",
       "          [ 2.7539e-01, -1.0303e-01, -5.2344e-01,  ...,  3.9648e-01,\n",
       "            2.5586e-01,  2.7539e-01],\n",
       "          [ 4.3945e-01,  4.4556e-03,  2.1875e-01,  ..., -1.9824e-01,\n",
       "           -1.7871e-01, -3.7695e-01]],\n",
       "\n",
       "         [[-2.0508e-02,  1.2695e-02, -2.8198e-02,  ...,  2.8610e-04,\n",
       "           -5.9509e-03, -1.1719e-02],\n",
       "          [-5.0293e-02,  4.4727e-01, -4.5312e-01,  ..., -1.8188e-02,\n",
       "            4.2480e-02, -6.5613e-03],\n",
       "          [ 9.9609e-02,  7.3828e-01,  4.0234e-01,  ..., -2.6562e-01,\n",
       "            9.2163e-03, -1.8652e-01],\n",
       "          ...,\n",
       "          [ 8.0078e-01,  5.2344e-01, -7.6562e-01,  ..., -2.4121e-01,\n",
       "            5.1953e-01,  3.1641e-01],\n",
       "          [ 9.4238e-02,  3.3398e-01, -8.3203e-01,  ..., -2.5586e-01,\n",
       "            3.0078e-01, -6.9531e-01],\n",
       "          [ 2.8906e-01, -5.4688e-01, -2.5195e-01,  ..., -4.9414e-01,\n",
       "            1.4062e-01, -1.8359e-01]],\n",
       "\n",
       "         [[-6.0272e-04,  4.9438e-03, -2.0996e-02,  ...,  3.9062e-03,\n",
       "            1.2390e-02, -1.0193e-02],\n",
       "          [-3.7695e-01, -3.3789e-01,  5.9766e-01,  ..., -9.9609e-01,\n",
       "            2.6367e-01, -6.1035e-02],\n",
       "          [-5.4688e-01, -1.9238e-01,  7.5391e-01,  ..., -9.1797e-01,\n",
       "           -4.2773e-01, -1.6504e-01],\n",
       "          ...,\n",
       "          [-2.6758e-01,  4.9805e-01,  8.3496e-02,  ..., -6.5234e-01,\n",
       "           -4.6387e-02, -2.4219e-01],\n",
       "          [-1.9688e+00, -9.8145e-02, -9.6094e-01,  ...,  3.7305e-01,\n",
       "            3.7500e-01,  2.3560e-02],\n",
       "          [-5.3906e-01,  3.9307e-02, -1.0938e-01,  ...,  4.1504e-02,\n",
       "           -6.5625e-01,  1.5234e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-3.5248e-03, -1.5991e-02, -1.3184e-02,  ...,  7.7637e-02,\n",
       "           -8.5938e-02, -5.9814e-02],\n",
       "          [-8.3594e-01, -2.5391e-02, -6.5234e-01,  ...,  4.3701e-02,\n",
       "           -9.0625e-01, -1.4844e-01],\n",
       "          [-3.4766e-01, -3.1445e-01,  1.0703e+00,  ...,  1.3477e-01,\n",
       "           -2.9297e-01, -1.0859e+00],\n",
       "          ...,\n",
       "          [ 1.8828e+00, -3.1641e-01,  7.8906e-01,  ..., -3.8672e-01,\n",
       "            2.1875e+00, -1.6016e+00],\n",
       "          [-1.5332e-01,  1.3672e+00, -5.9766e-01,  ..., -1.2109e+00,\n",
       "            1.3906e+00,  7.7637e-02],\n",
       "          [-9.3750e-01,  3.5352e-01,  3.2031e-01,  ...,  8.3203e-01,\n",
       "            4.5312e-01, -9.5703e-01]],\n",
       "\n",
       "         [[ 8.1787e-03,  9.5825e-03,  4.1389e-04,  ..., -9.6436e-03,\n",
       "            2.2461e-01, -2.3340e-01],\n",
       "          [ 3.3594e-01, -1.1816e-01, -4.4531e-01,  ...,  1.7031e+00,\n",
       "           -2.3906e+00,  1.4531e+00],\n",
       "          [ 5.4297e-01, -5.8594e-01, -2.4805e-01,  ...,  2.1562e+00,\n",
       "           -1.8359e+00,  1.1875e+00],\n",
       "          ...,\n",
       "          [-1.8359e+00,  9.1797e-02, -7.0703e-01,  ...,  1.5625e+00,\n",
       "           -9.8438e-01,  4.5117e-01],\n",
       "          [-1.3184e-02, -4.9023e-01,  2.5391e-02,  ...,  1.2578e+00,\n",
       "            1.3281e+00,  1.2500e+00],\n",
       "          [ 7.1289e-02, -5.2734e-01, -3.3594e-01,  ...,  2.1562e+00,\n",
       "           -4.1016e-01,  6.4844e-01]],\n",
       "\n",
       "         [[-2.1820e-03, -6.9275e-03, -1.0193e-02,  ...,  5.5420e-02,\n",
       "            2.6758e-01, -8.1250e-01],\n",
       "          [-1.3203e+00,  4.4922e-02, -1.8066e-01,  ...,  2.4512e-01,\n",
       "           -1.9727e-01,  2.0469e+00],\n",
       "          [ 2.8320e-01, -1.6250e+00, -4.6289e-01,  ...,  6.8359e-01,\n",
       "           -8.0469e-01,  2.8750e+00],\n",
       "          ...,\n",
       "          [ 4.4141e-01,  1.5000e+00, -8.7109e-01,  ...,  1.7734e+00,\n",
       "           -1.5469e+00,  1.0703e+00],\n",
       "          [-1.3379e-01,  1.4766e+00, -1.3359e+00,  ...,  7.6953e-01,\n",
       "           -1.2109e+00,  2.4219e+00],\n",
       "          [-1.7656e+00,  1.0078e+00, -1.6406e-01,  ..., -1.7734e+00,\n",
       "           -1.2656e+00,  1.8672e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.9165e-02,  1.0742e-02, -1.8921e-02,  ...,  2.4512e-01,\n",
       "            1.0254e-01, -6.6406e-02],\n",
       "          [ 2.7344e-02, -4.7852e-01,  5.2734e-01,  ...,  2.2031e+00,\n",
       "           -2.5781e-01, -5.6641e-01],\n",
       "          [ 8.9844e-02, -3.4766e-01,  2.0312e-01,  ..., -3.2812e-01,\n",
       "           -7.8906e-01, -3.1250e-01],\n",
       "          ...,\n",
       "          [-1.1230e-02,  2.8125e-01, -3.4912e-02,  ..., -1.1797e+00,\n",
       "           -8.0469e-01,  2.3281e+00],\n",
       "          [ 8.6914e-02, -2.2852e-01,  6.8359e-02,  ...,  4.3945e-02,\n",
       "           -1.7656e+00, -2.3145e-01],\n",
       "          [-4.8828e-02, -9.0820e-02,  4.0430e-01,  ..., -8.9844e-02,\n",
       "           -1.1094e+00,  8.3984e-01]],\n",
       "\n",
       "         [[-1.0071e-02, -5.5847e-03, -1.8311e-02,  ...,  2.4048e-02,\n",
       "            2.9102e-01,  5.0537e-02],\n",
       "          [ 5.8984e-01, -1.7578e-02, -2.5000e-01,  ...,  5.8984e-01,\n",
       "           -5.1562e-01, -7.0703e-01],\n",
       "          [ 2.5391e-01, -5.3906e-01, -5.2490e-03,  ...,  1.0547e+00,\n",
       "           -2.7344e-01, -4.3555e-01],\n",
       "          ...,\n",
       "          [ 6.5625e-01,  3.2812e-01,  3.0469e-01,  ...,  5.4688e-01,\n",
       "           -2.4062e+00, -2.7734e-01],\n",
       "          [ 1.3281e+00, -6.7969e-01, -1.8164e-01,  ...,  1.2891e+00,\n",
       "           -1.4609e+00,  3.0859e-01],\n",
       "          [ 7.0312e-02, -1.1328e+00,  3.5156e-02,  ...,  2.6758e-01,\n",
       "           -1.4922e+00, -2.2888e-03]],\n",
       "\n",
       "         [[ 8.4229e-03, -9.2163e-03,  2.2461e-02,  ...,  1.4062e-01,\n",
       "            3.3203e-01, -1.8438e+00],\n",
       "          [ 8.8672e-01, -2.9688e-01, -1.4258e-01,  ...,  6.2109e-01,\n",
       "           -1.5859e+00,  6.8750e+00],\n",
       "          [-6.7383e-02,  2.4121e-01, -2.0312e-01,  ...,  1.0469e+00,\n",
       "           -1.3906e+00,  7.8750e+00],\n",
       "          ...,\n",
       "          [ 3.3984e-01, -1.1641e+00,  8.7109e-01,  ..., -4.3750e-01,\n",
       "           -1.4844e+00,  7.8750e+00],\n",
       "          [ 8.2422e-01, -8.4375e-01, -6.9531e-01,  ..., -2.8125e-01,\n",
       "           -1.5000e+00,  8.0625e+00],\n",
       "          [ 2.9688e-01, -6.1328e-01, -8.6328e-01,  ...,  1.8203e+00,\n",
       "           -1.4297e+00,  7.0938e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 9.3994e-03,  2.8931e-02,  1.8433e-02,  ..., -3.2715e-02,\n",
       "           -1.2131e-03, -8.4229e-03],\n",
       "          [-1.6699e-01, -5.2344e-01,  8.9062e-01,  ...,  6.3672e-01,\n",
       "           -1.2188e+00, -6.9141e-01],\n",
       "          [-1.8750e-01, -2.3535e-01,  3.9648e-01,  ...,  6.7969e-01,\n",
       "           -6.7578e-01, -8.3203e-01],\n",
       "          ...,\n",
       "          [ 3.6133e-01,  9.3750e-02, -4.7607e-03,  ...,  1.1215e-03,\n",
       "            1.4941e-01, -6.3281e-01],\n",
       "          [-6.3281e-01,  3.2812e-01, -5.2734e-02,  ..., -6.9531e-01,\n",
       "           -6.4844e-01, -6.0547e-01],\n",
       "          [ 2.2095e-02, -3.6914e-01,  4.9414e-01,  ..., -7.7344e-01,\n",
       "           -1.7480e-01,  9.2773e-03]],\n",
       "\n",
       "         [[ 6.1340e-03, -1.6556e-03, -2.5940e-03,  ..., -1.4526e-02,\n",
       "            4.7302e-03,  5.1270e-03],\n",
       "          [ 6.7188e-01,  1.1719e+00,  1.7676e-01,  ..., -4.3164e-01,\n",
       "           -6.2891e-01,  9.4531e-01],\n",
       "          [ 1.3672e+00,  6.5234e-01,  2.6953e-01,  ..., -5.1562e-01,\n",
       "           -2.8320e-01,  9.8828e-01],\n",
       "          ...,\n",
       "          [ 3.6914e-01, -3.5938e-01,  4.5703e-01,  ..., -1.9453e+00,\n",
       "           -1.0107e-01,  1.3574e-01],\n",
       "          [ 3.7695e-01,  1.1133e-01, -5.5664e-02,  ..., -7.4219e-01,\n",
       "            5.5859e-01,  9.8047e-01],\n",
       "          [ 9.4141e-01,  2.5781e-01,  6.1523e-02,  ..., -5.1562e-01,\n",
       "           -2.9419e-02, -1.4062e-01]],\n",
       "\n",
       "         [[ 1.1963e-02,  1.5991e-02,  1.8066e-02,  ..., -4.7913e-03,\n",
       "            1.3367e-02, -2.5024e-02],\n",
       "          [-9.9609e-02,  1.0469e+00,  2.9883e-01,  ..., -4.1211e-01,\n",
       "            4.3750e-01,  3.5156e-01],\n",
       "          [-8.3984e-01,  1.1484e+00,  1.9336e-01,  ..., -3.8281e-01,\n",
       "            1.5723e-01,  1.3359e+00],\n",
       "          ...,\n",
       "          [ 6.6797e-01,  1.5234e+00,  7.7734e-01,  ..., -8.4766e-01,\n",
       "            1.1719e+00,  3.4766e-01],\n",
       "          [ 8.2812e-01,  5.5078e-01, -1.3477e-01,  ..., -1.4141e+00,\n",
       "            5.8203e-01,  5.4688e-01],\n",
       "          [ 5.8105e-02,  2.3242e-01,  7.3828e-01,  ..., -1.1172e+00,\n",
       "           -1.4453e-01, -1.1641e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.0742e-02,  3.9062e-03,  2.1577e-05,  ..., -4.4250e-03,\n",
       "           -8.7280e-03, -8.4229e-03],\n",
       "          [ 4.9023e-01, -2.6758e-01,  7.9297e-01,  ..., -2.5586e-01,\n",
       "           -2.0703e-01,  3.9453e-01],\n",
       "          [-9.5215e-02,  1.5991e-02,  8.8672e-01,  ...,  1.8750e-01,\n",
       "           -1.9043e-02,  7.2266e-01],\n",
       "          ...,\n",
       "          [ 1.4609e+00, -3.7305e-01, -6.8848e-02,  ...,  1.4062e-01,\n",
       "            8.7891e-02,  3.7109e-02],\n",
       "          [-9.1797e-01, -5.2246e-02, -5.4688e-01,  ...,  1.6328e+00,\n",
       "            1.0312e+00,  1.4160e-01],\n",
       "          [-8.4766e-01,  1.0840e-01,  8.0078e-01,  ...,  1.0205e-01,\n",
       "           -6.8359e-01, -9.2188e-01]],\n",
       "\n",
       "         [[-3.1494e-02,  4.3335e-03, -2.3438e-02,  ..., -2.5024e-02,\n",
       "            5.3711e-03, -5.5420e-02],\n",
       "          [ 1.2354e-01, -1.4844e-01,  1.0234e+00,  ...,  2.6758e-01,\n",
       "            8.9844e-01, -1.2578e+00],\n",
       "          [ 5.3906e-01, -5.7031e-01,  7.8906e-01,  ...,  3.3594e-01,\n",
       "            3.8086e-01, -1.1797e+00],\n",
       "          ...,\n",
       "          [ 5.5859e-01, -1.2969e+00,  2.1680e-01,  ...,  2.4023e-01,\n",
       "           -2.4048e-02,  3.4961e-01],\n",
       "          [-3.0273e-01, -5.7812e-01,  1.0234e+00,  ..., -7.5391e-01,\n",
       "           -1.5391e+00, -6.8750e-01],\n",
       "          [ 2.3535e-01, -1.4453e+00,  9.9609e-01,  ...,  8.5938e-01,\n",
       "            1.4941e-01,  3.5156e-01]],\n",
       "\n",
       "         [[-2.1057e-03, -9.8267e-03,  2.8687e-02,  ...,  7.5340e-05,\n",
       "           -3.4668e-02,  1.8311e-02],\n",
       "          [-2.5781e-01, -1.5918e-01, -1.4844e-01,  ..., -5.3516e-01,\n",
       "           -5.2344e-01, -3.4961e-01],\n",
       "          [ 3.0078e-01, -8.2520e-02, -4.0039e-01,  ..., -1.3281e-01,\n",
       "           -1.1641e+00, -3.8818e-02],\n",
       "          ...,\n",
       "          [ 1.4609e+00, -1.4609e+00,  2.2827e-02,  ...,  3.5938e-01,\n",
       "            1.4258e-01,  1.5547e+00],\n",
       "          [ 7.8125e-01, -4.1016e-01,  5.1514e-02,  ..., -2.8320e-01,\n",
       "            1.4219e+00,  9.3750e-01],\n",
       "          [ 1.1406e+00, -1.6211e-01, -8.2031e-01,  ..., -5.1172e-01,\n",
       "           -1.0547e+00,  1.1797e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 1.2207e-02,  1.3916e-02,  3.8757e-03,  ...,  1.1328e-01,\n",
       "            4.5312e-01, -1.9238e-01],\n",
       "          [-5.4688e-01,  8.5156e-01,  6.0156e-01,  ..., -2.2888e-03,\n",
       "           -1.1953e+00, -1.0986e-01],\n",
       "          [ 2.1680e-01,  2.8711e-01,  4.8633e-01,  ...,  4.8633e-01,\n",
       "            3.0884e-02,  7.5781e-01],\n",
       "          ...,\n",
       "          [-8.5547e-01, -3.7109e-01, -7.5781e-01,  ..., -1.2988e-01,\n",
       "           -1.7344e+00,  9.4531e-01],\n",
       "          [-7.8125e-02, -6.5430e-02, -2.0410e-01,  ...,  6.4844e-01,\n",
       "           -1.1094e+00,  8.9453e-01],\n",
       "          [ 4.2188e-01, -4.7119e-02, -2.4414e-02,  ..., -3.9062e-01,\n",
       "           -2.2812e+00,  5.1953e-01]],\n",
       "\n",
       "         [[-2.2583e-02,  5.2795e-03,  5.6152e-03,  ..., -9.7656e-02,\n",
       "            7.4707e-02, -9.0332e-02],\n",
       "          [ 1.5156e+00,  4.4531e-01,  5.7812e-01,  ..., -9.7656e-01,\n",
       "            1.0547e+00, -1.0781e+00],\n",
       "          [ 7.9688e-01,  2.9102e-01,  3.9648e-01,  ..., -5.1953e-01,\n",
       "            9.4922e-01, -1.4062e+00],\n",
       "          ...,\n",
       "          [-1.2031e+00,  4.7461e-01,  3.1641e-01,  ..., -1.3047e+00,\n",
       "            6.4844e-01, -1.0312e+00],\n",
       "          [ 3.3594e-01, -3.7305e-01,  1.0938e-01,  ...,  8.2422e-01,\n",
       "           -1.7500e+00, -1.9219e+00],\n",
       "          [ 7.8125e-01,  2.4023e-01, -1.4160e-02,  ..., -2.6406e+00,\n",
       "            8.4473e-02, -7.2656e-01]],\n",
       "\n",
       "         [[-2.9602e-03, -1.5564e-02,  3.3951e-04,  ...,  4.2578e-01,\n",
       "            2.7148e-01, -1.9336e-01],\n",
       "          [-1.0000e+00,  2.3438e-01,  5.7812e-01,  ...,  1.0449e-01,\n",
       "           -6.1328e-01,  1.3438e+00],\n",
       "          [-4.6094e-01,  7.5391e-01, -4.2773e-01,  ..., -1.0078e+00,\n",
       "           -1.6016e+00,  7.1875e-01],\n",
       "          ...,\n",
       "          [-4.1211e-01, -6.5625e-01,  1.4941e-01,  ...,  4.6875e-01,\n",
       "           -1.6602e-02,  1.0254e-01],\n",
       "          [ 1.3516e+00,  4.1406e-01, -1.1562e+00,  ..., -1.5039e-01,\n",
       "           -1.6406e+00,  8.7891e-01],\n",
       "          [-6.2012e-02, -2.0312e-01, -7.3242e-02,  ..., -5.1953e-01,\n",
       "           -1.2969e+00, -1.6699e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.2512e-02,  1.8066e-02, -5.9204e-03,  ..., -1.5312e+00,\n",
       "            1.8359e-01, -5.0000e-01],\n",
       "          [-1.0000e+00, -4.1992e-01, -2.3438e-01,  ...,  3.5469e+00,\n",
       "           -1.7344e+00, -2.5781e-01],\n",
       "          [-1.2031e+00, -2.7344e-01, -9.9609e-01,  ...,  4.6250e+00,\n",
       "           -1.1172e+00,  1.9922e-01],\n",
       "          ...,\n",
       "          [ 5.9375e-01,  7.9297e-01,  4.4727e-01,  ...,  4.0000e+00,\n",
       "           -9.2578e-01, -1.5469e+00],\n",
       "          [ 1.2422e+00, -2.9102e-01,  2.8320e-01,  ...,  4.7188e+00,\n",
       "           -7.0312e-01, -1.7344e+00],\n",
       "          [ 6.4844e-01,  6.0156e-01,  1.0234e+00,  ...,  4.1562e+00,\n",
       "           -1.3438e+00, -1.8125e+00]],\n",
       "\n",
       "         [[ 1.6602e-02, -1.5625e-02, -1.5015e-02,  ..., -1.7773e-01,\n",
       "            6.1035e-02, -2.1973e-03],\n",
       "          [ 2.9297e-01, -4.8047e-01, -3.9062e-03,  ..., -2.4121e-01,\n",
       "           -1.4766e+00,  1.5938e+00],\n",
       "          [ 5.7812e-01,  1.6895e-01, -4.5312e-01,  ...,  3.0859e-01,\n",
       "           -7.5391e-01,  1.8828e+00],\n",
       "          ...,\n",
       "          [-1.5938e+00,  4.3945e-01, -3.1250e-02,  ...,  6.9922e-01,\n",
       "           -2.8125e-01,  1.0469e+00],\n",
       "          [-2.9062e+00, -1.3750e+00, -4.5898e-01,  ...,  9.6875e-01,\n",
       "            9.5312e-01,  1.0859e+00],\n",
       "          [-9.6484e-01, -1.4844e-01, -4.7656e-01,  ..., -5.0391e-01,\n",
       "           -1.7969e-01, -4.3750e-01]],\n",
       "\n",
       "         [[-2.5269e-02,  1.8799e-02,  8.1062e-05,  ..., -3.0273e-01,\n",
       "           -8.3984e-02, -1.5137e-01],\n",
       "          [ 1.8438e+00, -1.6406e+00,  8.9844e-01,  ...,  8.2031e-01,\n",
       "            4.2236e-02, -1.5137e-02],\n",
       "          [ 9.2188e-01, -1.7266e+00,  4.1016e-01,  ...,  1.4609e+00,\n",
       "            6.7188e-01, -4.5703e-01],\n",
       "          ...,\n",
       "          [-1.4531e+00,  2.4219e+00, -1.5234e+00,  ...,  4.8584e-02,\n",
       "            7.3828e-01,  3.7891e-01],\n",
       "          [ 1.3047e+00,  2.1250e+00, -3.7891e-01,  ...,  1.9766e+00,\n",
       "           -3.8672e-01,  6.9922e-01],\n",
       "          [ 1.3125e+00, -2.4414e-01,  4.7266e-01,  ...,  2.7734e-01,\n",
       "            2.3828e-01,  6.3965e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[ 1.4954e-02, -6.3171e-03, -9.4604e-03,  ...,  7.8735e-03,\n",
       "            1.0803e-02, -1.9775e-02],\n",
       "          [ 5.4443e-02, -5.2734e-01,  7.3438e-01,  ..., -9.1406e-01,\n",
       "           -9.1406e-01,  1.3594e+00],\n",
       "          [-5.9375e-01, -4.1406e-01,  5.3906e-01,  ..., -1.7773e-01,\n",
       "           -1.2188e+00,  1.3359e+00],\n",
       "          ...,\n",
       "          [-1.3672e+00, -1.3594e+00,  6.6016e-01,  ..., -2.7734e-01,\n",
       "            9.3750e-01,  6.0425e-03],\n",
       "          [-9.4531e-01, -7.8613e-02, -7.9297e-01,  ...,  1.0781e+00,\n",
       "            7.5391e-01, -5.3906e-01],\n",
       "          [-1.1797e+00, -5.7422e-01,  6.4062e-01,  ...,  2.2852e-01,\n",
       "           -2.4170e-02, -2.6172e-01]],\n",
       "\n",
       "         [[-4.8218e-03, -2.4872e-03,  2.4780e-02,  ...,  6.7444e-03,\n",
       "            5.4932e-03,  1.1169e-02],\n",
       "          [-9.9121e-02,  1.1621e-01, -1.6309e-01,  ...,  4.2383e-01,\n",
       "           -1.0547e+00, -2.6953e-01],\n",
       "          [-4.9316e-02,  4.1797e-01, -8.7500e-01,  ...,  8.5156e-01,\n",
       "           -6.7871e-02,  1.2598e-01],\n",
       "          ...,\n",
       "          [-7.8125e-01,  1.6797e-01,  7.6953e-01,  ...,  1.3125e+00,\n",
       "           -6.5625e-01,  1.4160e-01],\n",
       "          [-7.9297e-01, -1.1250e+00,  8.5156e-01,  ..., -1.2891e-01,\n",
       "            3.5547e-01,  5.8203e-01],\n",
       "          [-1.3281e-01,  7.7734e-01, -1.4160e-02,  ...,  6.8848e-02,\n",
       "            2.0215e-01, -3.5352e-01]],\n",
       "\n",
       "         [[ 8.4839e-03,  1.1658e-02,  2.8687e-02,  ..., -2.3560e-02,\n",
       "            2.7344e-02,  3.8330e-02],\n",
       "          [ 1.1641e+00,  6.9531e-01,  1.2988e-01,  ...,  8.3203e-01,\n",
       "            9.7656e-01, -3.3203e-01],\n",
       "          [ 4.1992e-01,  9.0234e-01,  4.3945e-01,  ...,  1.2188e+00,\n",
       "            9.6094e-01, -7.6172e-01],\n",
       "          ...,\n",
       "          [ 3.1006e-02, -1.4941e-01, -1.4531e+00,  ...,  5.0000e-01,\n",
       "           -3.5742e-01, -8.5938e-02],\n",
       "          [ 3.1641e-01, -2.7148e-01, -1.6309e-01,  ...,  2.6562e-01,\n",
       "           -7.0703e-01, -6.5234e-01],\n",
       "          [ 2.0000e+00,  7.2266e-01,  2.7734e-01,  ...,  1.0703e+00,\n",
       "            2.8516e-01,  1.1797e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.3489e-02,  6.9580e-03, -1.1902e-02,  ..., -8.1787e-03,\n",
       "            9.9487e-03, -3.6316e-03],\n",
       "          [ 3.2031e-01, -3.3594e-01, -5.0781e-01,  ...,  3.0078e-01,\n",
       "           -1.2988e-01,  1.2734e+00],\n",
       "          [-5.9375e-01, -4.7852e-01, -1.0078e+00,  ...,  2.5757e-02,\n",
       "            1.8848e-01,  8.5156e-01],\n",
       "          ...,\n",
       "          [ 1.5918e-01,  3.1250e-01, -3.4668e-02,  ...,  2.1973e-01,\n",
       "            1.1484e+00, -9.0625e-01],\n",
       "          [-5.1172e-01,  4.1211e-01, -5.5469e-01,  ...,  2.1406e+00,\n",
       "            3.8867e-01, -1.8281e+00],\n",
       "          [ 8.0859e-01,  7.5391e-01,  4.8047e-01,  ...,  6.0156e-01,\n",
       "            3.6914e-01, -3.1250e-01]],\n",
       "\n",
       "         [[ 2.5513e-02,  1.8921e-02, -1.8066e-02,  ..., -5.9082e-02,\n",
       "            7.5195e-02, -2.4414e-02],\n",
       "          [-4.4727e-01, -4.3945e-01,  4.4336e-01,  ..., -1.2031e+00,\n",
       "           -1.1406e+00,  1.2188e+00],\n",
       "          [-1.1562e+00,  6.9531e-01,  2.3828e-01,  ..., -1.2969e+00,\n",
       "           -8.9062e-01,  8.7500e-01],\n",
       "          ...,\n",
       "          [ 5.7422e-01, -1.7090e-01,  6.0547e-02,  ..., -4.7070e-01,\n",
       "           -5.6641e-01,  6.7969e-01],\n",
       "          [ 5.9375e-01,  1.5527e-01, -6.7578e-01,  ...,  1.1250e+00,\n",
       "           -2.2031e+00,  7.3828e-01],\n",
       "          [-1.9897e-02, -7.5000e-01, -6.0156e-01,  ...,  2.7734e-01,\n",
       "           -1.1328e+00,  1.2969e+00]],\n",
       "\n",
       "         [[ 2.5757e-02, -6.3324e-04, -1.2634e-02,  ..., -7.3853e-03,\n",
       "           -1.4893e-02, -8.8501e-03],\n",
       "          [-4.3555e-01, -4.4434e-02, -2.2363e-01,  ..., -3.2031e-01,\n",
       "            7.8125e-01, -1.0547e+00],\n",
       "          [-2.5195e-01, -1.0449e-01, -8.8672e-01,  ..., -8.2812e-01,\n",
       "            9.6484e-01, -2.2852e-01],\n",
       "          ...,\n",
       "          [-7.9956e-03,  8.3984e-01,  2.8320e-01,  ..., -5.0000e-01,\n",
       "           -1.3516e+00, -3.4375e-01],\n",
       "          [-1.0312e+00, -2.3926e-01,  6.7578e-01,  ..., -3.8281e-01,\n",
       "           -5.4688e-01,  4.8828e-02],\n",
       "          [ 5.8203e-01,  4.0625e-01, -1.6895e-01,  ..., -1.5918e-01,\n",
       "           -1.4062e-01,  1.7773e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 2.1057e-03,  3.7109e-02, -1.8082e-03,  ...,  3.0664e-01,\n",
       "            6.5918e-02, -1.9434e-01],\n",
       "          [-8.1250e-01, -9.1016e-01, -1.2344e+00,  ..., -9.7656e-01,\n",
       "           -8.3984e-01,  3.1445e-01],\n",
       "          [-8.9844e-01, -8.7109e-01,  9.1309e-02,  ..., -5.4688e-01,\n",
       "           -2.7344e-01, -4.2773e-01],\n",
       "          ...,\n",
       "          [-3.2227e-02, -4.1602e-01,  1.0078e+00,  ...,  1.2207e-02,\n",
       "            8.9453e-01,  9.3750e-02],\n",
       "          [ 8.0078e-01, -1.0156e+00, -2.5586e-01,  ..., -1.3984e+00,\n",
       "            1.6172e+00, -5.8203e-01],\n",
       "          [-7.5000e-01,  3.6719e-01, -1.0234e+00,  ...,  2.9883e-01,\n",
       "            1.6328e+00,  1.6211e-01]],\n",
       "\n",
       "         [[-7.2098e-04,  1.3062e-02,  1.0254e-02,  ...,  6.2109e-01,\n",
       "            3.0273e-01,  5.2734e-01],\n",
       "          [ 1.5391e+00, -5.3906e-01, -5.3711e-02,  ..., -2.5156e+00,\n",
       "           -3.7695e-01, -2.7812e+00],\n",
       "          [ 6.6406e-01, -1.2207e-01,  5.1514e-02,  ..., -1.8516e+00,\n",
       "            6.0938e-01, -2.9219e+00],\n",
       "          ...,\n",
       "          [-1.2344e+00,  5.7812e-01,  9.7168e-02,  ..., -1.9844e+00,\n",
       "            4.4922e-01, -2.8750e+00],\n",
       "          [-4.6875e-01, -7.5000e-01,  9.5703e-01,  ..., -5.6250e-01,\n",
       "            1.0859e+00, -1.8516e+00],\n",
       "          [ 1.3984e+00,  4.7363e-02, -1.7383e-01,  ..., -2.6719e+00,\n",
       "           -1.1953e+00, -2.7812e+00]],\n",
       "\n",
       "         [[ 1.4954e-02,  1.6357e-02, -1.0605e-03,  ...,  4.6387e-02,\n",
       "            1.2656e+00, -1.2656e+00],\n",
       "          [-1.9922e-01,  2.4902e-02,  1.3281e-01,  ..., -1.1406e+00,\n",
       "           -1.6875e+00,  3.5469e+00],\n",
       "          [ 1.8750e-01, -1.1230e-01,  5.8594e-01,  ..., -6.4453e-02,\n",
       "           -1.4375e+00,  4.0938e+00],\n",
       "          ...,\n",
       "          [-1.0303e-01, -1.1719e-01, -1.3428e-02,  ..., -1.3438e+00,\n",
       "           -2.6719e+00,  3.0938e+00],\n",
       "          [ 1.0645e-01,  1.0547e-01,  2.2266e-01,  ..., -1.4062e+00,\n",
       "           -3.3750e+00,  2.5781e+00],\n",
       "          [ 2.6172e-01, -6.7383e-02, -9.0820e-02,  ...,  2.2266e-01,\n",
       "           -2.8281e+00,  4.8750e+00]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-4.1504e-03,  3.3569e-03, -7.5378e-03,  ..., -1.6016e-01,\n",
       "           -2.5391e-01, -1.4551e-01],\n",
       "          [ 7.8125e-01,  3.2617e-01,  3.9844e-01,  ..., -8.3203e-01,\n",
       "            5.1953e-01, -4.6484e-01],\n",
       "          [-5.9375e-01, -3.5742e-01,  1.2598e-01,  ...,  3.1738e-02,\n",
       "           -9.4141e-01,  6.1719e-01],\n",
       "          ...,\n",
       "          [-1.6406e-01, -3.6719e-01,  8.5938e-01,  ..., -1.8066e-01,\n",
       "            1.1406e+00,  7.9590e-02],\n",
       "          [ 1.5156e+00,  5.0391e-01, -5.5469e-01,  ..., -7.2754e-02,\n",
       "            1.6562e+00,  8.5449e-02],\n",
       "          [ 2.1719e+00,  7.4219e-01, -3.3984e-01,  ...,  7.8906e-01,\n",
       "            4.6289e-01,  4.1016e-01]],\n",
       "\n",
       "         [[-1.1597e-02,  1.2329e-02, -2.7100e-02,  ...,  9.9219e-01,\n",
       "           -1.3965e-01,  4.0039e-01],\n",
       "          [ 5.6250e-01, -6.4453e-01,  5.8594e-01,  ..., -3.3906e+00,\n",
       "            1.9766e+00, -4.6387e-02],\n",
       "          [ 1.2402e-01,  2.3438e-01,  4.7070e-01,  ..., -3.9531e+00,\n",
       "            1.8906e+00, -3.4570e-01],\n",
       "          ...,\n",
       "          [-3.4668e-02,  3.9453e-01, -8.9844e-02,  ..., -2.9688e+00,\n",
       "            2.2812e+00, -8.4766e-01],\n",
       "          [-1.1035e-01,  1.1328e-01, -4.4336e-01,  ..., -1.3984e+00,\n",
       "            3.1250e-01,  4.5654e-02],\n",
       "          [-7.2656e-01, -1.2812e+00, -1.2988e-01,  ..., -3.0000e+00,\n",
       "            1.4062e+00,  1.9434e-01]],\n",
       "\n",
       "         [[ 7.2327e-03, -7.4463e-03,  1.3855e-02,  ...,  2.3242e-01,\n",
       "           -3.1250e-01, -7.9297e-01],\n",
       "          [ 4.4434e-02, -1.4746e-01, -1.2402e-01,  ..., -1.4844e-01,\n",
       "            1.3672e+00,  2.1406e+00],\n",
       "          [-1.5234e-01,  1.8164e-01, -7.8125e-01,  ..., -1.3125e+00,\n",
       "            1.3203e+00,  2.6719e+00],\n",
       "          ...,\n",
       "          [ 6.0303e-02,  2.8125e-01,  1.2695e-01,  ..., -8.8672e-01,\n",
       "            6.4062e-01,  3.9375e+00],\n",
       "          [-1.0938e-01,  4.7363e-02,  4.6484e-01,  ..., -3.0078e-01,\n",
       "            1.5312e+00,  3.6875e+00],\n",
       "          [ 3.3203e-01, -4.3750e-01, -4.2969e-01,  ..., -1.5781e+00,\n",
       "            1.9062e+00,  3.0156e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-7.8125e-03,  3.5553e-03, -1.5076e-02,  ..., -8.1177e-03,\n",
       "            2.0409e-04, -9.2773e-03],\n",
       "          [-8.1250e-01,  1.1953e+00,  5.0000e-01,  ...,  1.2344e+00,\n",
       "            1.2578e+00, -7.3047e-01],\n",
       "          [-4.6289e-01, -1.6406e-01,  4.6680e-01,  ...,  4.9023e-01,\n",
       "            1.0469e+00, -3.3398e-01],\n",
       "          ...,\n",
       "          [-2.0898e-01, -8.3984e-01, -3.1494e-02,  ...,  1.0742e-01,\n",
       "           -4.7656e-01,  7.0703e-01],\n",
       "          [ 2.1680e-01, -1.7773e-01,  9.3359e-01,  ..., -5.0781e-01,\n",
       "            8.1543e-02, -4.5410e-02],\n",
       "          [ 5.1953e-01,  3.6523e-01,  6.4062e-01,  ...,  1.5938e+00,\n",
       "            9.3359e-01, -1.1172e+00]],\n",
       "\n",
       "         [[ 6.5994e-04,  4.7913e-03,  3.0823e-03,  ...,  1.7014e-03,\n",
       "            4.5471e-03,  1.8677e-02],\n",
       "          [-5.2734e-02,  8.0078e-01,  3.3594e-01,  ..., -6.0730e-03,\n",
       "           -4.1211e-01, -3.4961e-01],\n",
       "          [-6.4453e-01,  5.5859e-01,  5.3516e-01,  ..., -1.4746e-01,\n",
       "           -7.4219e-01, -2.6562e-01],\n",
       "          ...,\n",
       "          [-6.5234e-01,  6.0791e-02, -4.8633e-01,  ..., -1.2598e-01,\n",
       "            6.2891e-01,  3.0078e-01],\n",
       "          [-1.4141e+00, -4.8047e-01, -1.4844e-01,  ...,  1.2578e+00,\n",
       "           -1.0107e-01, -3.9648e-01],\n",
       "          [-5.5469e-01,  1.2031e+00, -5.5078e-01,  ..., -3.7695e-01,\n",
       "           -1.8945e-01,  8.7500e-01]],\n",
       "\n",
       "         [[-2.6123e-02, -2.0142e-02, -3.1738e-02,  ...,  1.1536e-02,\n",
       "            3.5889e-02, -5.7129e-02],\n",
       "          [-1.0859e+00, -2.4609e-01, -2.2656e-01,  ...,  5.0391e-01,\n",
       "           -6.9531e-01,  3.1250e-01],\n",
       "          [-7.4219e-01, -7.5684e-02, -7.5781e-01,  ..., -2.1680e-01,\n",
       "           -1.2422e+00,  1.8555e-01],\n",
       "          ...,\n",
       "          [ 1.3770e-01,  6.5625e-01,  6.0547e-01,  ...,  7.0312e-01,\n",
       "            1.4062e-01, -2.2363e-01],\n",
       "          [-2.9883e-01,  9.7266e-01,  1.8848e-01,  ...,  9.8438e-01,\n",
       "            2.9883e-01, -6.7578e-01],\n",
       "          [-4.7852e-01,  2.8125e-01,  6.6406e-01,  ..., -6.8359e-02,\n",
       "            8.6426e-02,  1.2988e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.0742e-02,  2.9053e-02, -1.4404e-02,  ..., -3.6621e-02,\n",
       "           -5.3467e-02, -4.0527e-02],\n",
       "          [ 3.1250e-01,  6.7969e-01, -1.6016e-01,  ...,  8.2812e-01,\n",
       "            2.0801e-01,  6.4453e-02],\n",
       "          [ 7.3047e-01,  3.1445e-01, -5.5469e-01,  ...,  2.6367e-01,\n",
       "            4.6289e-01,  5.3125e-01],\n",
       "          ...,\n",
       "          [ 1.3750e+00, -7.2656e-01,  1.1719e+00,  ...,  1.1250e+00,\n",
       "            7.9688e-01,  5.7812e-01],\n",
       "          [ 7.5391e-01,  2.1191e-01,  5.9375e-01,  ..., -2.6562e-01,\n",
       "            1.0938e+00, -1.5938e+00],\n",
       "          [ 5.7422e-01, -3.5742e-01, -3.4375e-01,  ...,  1.0078e+00,\n",
       "            3.7891e-01,  2.0312e-01]],\n",
       "\n",
       "         [[-1.1749e-03,  1.4771e-02,  9.3994e-03,  ...,  1.0803e-02,\n",
       "           -1.5259e-02, -3.2715e-02],\n",
       "          [-4.2383e-01,  7.1484e-01,  5.6152e-02,  ..., -2.1851e-02,\n",
       "            8.0859e-01,  6.1035e-02],\n",
       "          [-9.2188e-01,  4.4922e-01, -5.1172e-01,  ...,  3.3594e-01,\n",
       "            6.7578e-01, -2.5977e-01],\n",
       "          ...,\n",
       "          [ 2.2656e-01, -6.7578e-01, -1.5259e-02,  ...,  2.9907e-02,\n",
       "           -1.8262e-01,  4.6289e-01],\n",
       "          [ 7.3438e-01,  1.1035e-01,  4.0430e-01,  ...,  6.7578e-01,\n",
       "           -1.2578e+00,  1.1016e+00],\n",
       "          [ 3.2617e-01,  2.7148e-01,  5.0391e-01,  ..., -4.4922e-01,\n",
       "            1.9434e-01,  3.1445e-01]],\n",
       "\n",
       "         [[-1.3794e-02, -2.6733e-02, -1.6357e-02,  ...,  1.1719e-02,\n",
       "            1.8768e-03, -1.2085e-02],\n",
       "          [-1.7773e-01,  3.4375e-01, -4.3701e-02,  ...,  1.4688e+00,\n",
       "           -2.8320e-01, -3.3398e-01],\n",
       "          [ 6.4844e-01,  1.7188e-01, -3.0859e-01,  ...,  1.2969e+00,\n",
       "            2.5195e-01,  2.0508e-01],\n",
       "          ...,\n",
       "          [ 1.0078e+00,  2.5781e-01,  7.1875e-01,  ...,  1.2422e+00,\n",
       "           -9.3262e-02,  7.4707e-02],\n",
       "          [ 1.0469e+00,  1.8945e-01,  7.4609e-01,  ...,  2.2852e-01,\n",
       "           -8.6719e-01,  2.6758e-01],\n",
       "          [ 1.1875e+00,  9.6484e-01,  4.1016e-01,  ...,  1.0000e+00,\n",
       "           -3.9258e-01, -8.2031e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[ 8.3008e-03, -1.3855e-02, -2.3071e-02,  ..., -2.5146e-02,\n",
       "           -1.2061e-01,  8.7280e-03],\n",
       "          [-1.3906e+00, -2.0020e-01,  5.3906e-01,  ...,  4.3750e-01,\n",
       "           -3.9844e-01, -4.4922e-02],\n",
       "          [-8.0078e-01,  4.1602e-01,  1.3281e-01,  ..., -5.0391e-01,\n",
       "           -2.0312e-01, -2.2969e+00],\n",
       "          ...,\n",
       "          [ 6.7969e-01,  6.0059e-02, -1.4453e-01,  ...,  9.8047e-01,\n",
       "            1.6328e+00, -2.9688e+00],\n",
       "          [ 1.0625e+00,  7.2656e-01, -7.2266e-01,  ...,  7.0312e-01,\n",
       "           -4.3555e-01, -1.6328e+00],\n",
       "          [-9.2969e-01,  6.0938e-01, -3.8477e-01,  ..., -1.7090e-01,\n",
       "            1.0156e+00,  1.3594e+00]],\n",
       "\n",
       "         [[-1.2436e-03, -4.3640e-03, -5.7983e-03,  ..., -2.9541e-02,\n",
       "            3.2715e-02, -2.6758e-01],\n",
       "          [ 1.7969e-01, -5.4297e-01,  4.9805e-01,  ..., -1.1875e+00,\n",
       "            1.7266e+00, -4.9414e-01],\n",
       "          [ 3.4766e-01,  1.9043e-02,  5.3125e-01,  ..., -2.5938e+00,\n",
       "            1.8203e+00,  1.1484e+00],\n",
       "          ...,\n",
       "          [ 3.5938e-01, -7.7148e-02,  1.4941e-01,  ...,  9.2578e-01,\n",
       "            7.7734e-01, -3.5469e+00],\n",
       "          [ 7.5391e-01, -2.5156e+00,  7.2266e-02,  ...,  9.9219e-01,\n",
       "           -1.1768e-01,  6.7969e-01],\n",
       "          [ 1.8457e-01, -4.7461e-01, -2.3828e-01,  ...,  1.0303e-01,\n",
       "            2.4414e-01,  3.2617e-01]],\n",
       "\n",
       "         [[ 1.1353e-02, -6.5002e-03,  1.3855e-02,  ..., -4.3213e-02,\n",
       "            3.0762e-02,  2.2266e-01],\n",
       "          [ 1.3750e+00, -2.5195e-01,  7.6953e-01,  ..., -1.0938e+00,\n",
       "           -3.9258e-01, -1.0391e+00],\n",
       "          [ 4.2383e-01,  7.6953e-01,  2.1094e-01,  ..., -9.9219e-01,\n",
       "           -4.3945e-01, -2.7734e-01],\n",
       "          ...,\n",
       "          [-1.4258e-01, -9.4141e-01, -1.7266e+00,  ..., -9.4531e-01,\n",
       "            1.7383e-01, -5.0781e-01],\n",
       "          [ 7.6562e-01, -1.2812e+00, -1.5234e+00,  ...,  5.7812e-01,\n",
       "           -1.1719e-01, -9.9609e-01],\n",
       "          [ 9.4922e-01, -6.7188e-01,  5.5469e-01,  ..., -1.1641e+00,\n",
       "           -9.8828e-01, -6.6406e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.9775e-02, -2.5146e-02, -1.7456e-02,  ..., -3.5742e-01,\n",
       "           -5.2734e-01,  6.0059e-02],\n",
       "          [-3.7500e-01, -6.7188e-01,  6.9922e-01,  ..., -2.0781e+00,\n",
       "           -6.8359e-01,  1.2656e+00],\n",
       "          [-3.0859e-01, -1.2578e+00,  9.8633e-02,  ..., -1.8203e+00,\n",
       "            2.5586e-01,  6.2109e-01],\n",
       "          ...,\n",
       "          [-8.6328e-01,  4.2773e-01, -2.3438e-02,  ...,  9.5703e-01,\n",
       "            2.6758e-01, -8.6719e-01],\n",
       "          [-4.7852e-01, -4.7656e-01,  1.1719e+00,  ...,  9.3359e-01,\n",
       "           -1.9453e+00, -2.3535e-01],\n",
       "          [ 4.4727e-01,  6.5234e-01, -1.6602e-02,  ...,  8.8867e-02,\n",
       "           -8.3203e-01, -1.9219e+00]],\n",
       "\n",
       "         [[ 1.5869e-02, -1.6235e-02, -1.5869e-02,  ..., -1.3184e-02,\n",
       "            1.8311e-02,  1.3574e-01],\n",
       "          [-2.9688e-01, -2.3438e-01, -1.1562e+00,  ...,  1.4297e+00,\n",
       "           -5.3125e-01, -1.1016e+00],\n",
       "          [-1.2891e-01, -1.1641e+00, -2.4219e-01,  ...,  1.4531e+00,\n",
       "           -1.2734e+00, -1.1816e-01],\n",
       "          ...,\n",
       "          [ 4.0234e-01, -2.2168e-01,  4.8047e-01,  ...,  2.6406e+00,\n",
       "           -5.2734e-02, -1.6328e+00],\n",
       "          [ 1.8848e-01, -1.4355e-01,  8.1641e-01,  ...,  2.4688e+00,\n",
       "            7.9102e-02,  6.7188e-01],\n",
       "          [ 7.6562e-01, -3.0078e-01,  2.7148e-01,  ...,  1.8359e+00,\n",
       "           -1.7031e+00, -1.1953e+00]],\n",
       "\n",
       "         [[ 2.4414e-03,  1.9775e-02, -1.6113e-02,  ...,  1.2695e-01,\n",
       "            1.9409e-02,  2.4414e-01],\n",
       "          [ 4.6289e-01, -1.6094e+00, -5.7812e-01,  ...,  5.8984e-01,\n",
       "           -1.9165e-02,  9.9609e-01],\n",
       "          [-2.6172e-01, -1.5137e-01, -5.8984e-01,  ...,  8.8672e-01,\n",
       "            1.1406e+00,  3.2031e-01],\n",
       "          ...,\n",
       "          [-4.9414e-01,  7.2656e-01, -7.0312e-01,  ..., -2.9883e-01,\n",
       "            1.0234e+00, -9.6484e-01],\n",
       "          [-1.7969e-01, -1.3906e+00,  1.9062e+00,  ..., -1.9375e+00,\n",
       "           -4.5898e-01, -4.9805e-01],\n",
       "          [-6.2891e-01, -6.4062e-01,  1.2734e+00,  ..., -3.9453e-01,\n",
       "            3.7500e-01, -2.3828e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-9.5215e-03,  3.0151e-02, -3.9551e-02,  ...,  2.5024e-02,\n",
       "            2.3842e-04,  1.3672e-02],\n",
       "          [ 3.3203e-01,  1.1562e+00,  7.5000e-01,  ..., -9.6484e-01,\n",
       "            7.1289e-02, -1.0703e+00],\n",
       "          [ 6.9824e-02,  6.6406e-01,  4.5117e-01,  ..., -3.8477e-01,\n",
       "            4.5654e-02,  7.6172e-02],\n",
       "          ...,\n",
       "          [-8.0078e-01, -4.9414e-01,  5.0000e-01,  ..., -3.0469e-01,\n",
       "            8.8672e-01, -6.6797e-01],\n",
       "          [ 3.2812e-01,  2.7930e-01,  1.1484e+00,  ...,  7.3438e-01,\n",
       "            1.0938e+00,  4.2383e-01],\n",
       "          [-1.1719e+00,  6.5625e-01,  6.0938e-01,  ..., -1.6484e+00,\n",
       "            4.3164e-01, -7.5000e-01]],\n",
       "\n",
       "         [[-9.0942e-03,  5.9891e-04,  1.2756e-02,  ...,  1.6602e-02,\n",
       "           -2.3041e-03, -2.3438e-02],\n",
       "          [ 1.0234e+00, -8.0859e-01, -3.8086e-01,  ...,  3.5156e-01,\n",
       "           -1.6172e+00, -6.9531e-01],\n",
       "          [ 1.1641e+00, -1.1328e+00, -4.5898e-01,  ...,  6.0547e-01,\n",
       "           -1.4062e+00,  9.5215e-03],\n",
       "          ...,\n",
       "          [-2.1387e-01, -2.4902e-01,  1.1641e+00,  ...,  2.2559e-01,\n",
       "            1.0156e+00,  9.9219e-01],\n",
       "          [-1.0859e+00, -2.9492e-01,  8.9453e-01,  ...,  4.9805e-02,\n",
       "            1.1108e-02,  6.5625e-01],\n",
       "          [-4.7852e-01, -9.0234e-01, -6.2500e-01,  ...,  7.6294e-03,\n",
       "           -2.0605e-01,  3.2617e-01]],\n",
       "\n",
       "         [[-2.5482e-03,  1.5991e-02, -4.5410e-02,  ..., -2.0508e-02,\n",
       "            1.0803e-02,  2.0874e-02],\n",
       "          [-1.1250e+00,  1.1084e-01, -6.9531e-01,  ..., -6.9141e-01,\n",
       "           -2.7930e-01, -2.4023e-01],\n",
       "          [-5.3125e-01,  6.3281e-01, -4.9805e-01,  ..., -5.8984e-01,\n",
       "           -5.3516e-01, -3.0078e-01],\n",
       "          ...,\n",
       "          [ 9.8438e-01,  1.0000e+00, -1.9629e-01,  ..., -1.2422e+00,\n",
       "           -4.9023e-01, -8.7500e-01],\n",
       "          [ 7.4609e-01,  2.3242e-01, -1.0469e+00,  ..., -5.0781e-01,\n",
       "           -3.5156e-01,  2.6953e-01],\n",
       "          [ 9.4922e-01,  9.6680e-02, -7.5391e-01,  ...,  4.9219e-01,\n",
       "            6.1523e-02, -2.8516e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 9.9487e-03,  7.1335e-04,  1.9409e-02,  ..., -1.0925e-02,\n",
       "           -6.2866e-03,  1.7471e-03],\n",
       "          [ 5.8594e-01,  3.3789e-01,  3.6328e-01,  ...,  2.9297e-03,\n",
       "            3.2654e-03,  4.2578e-01],\n",
       "          [ 5.8203e-01,  3.0078e-01,  3.4766e-01,  ...,  1.2305e-01,\n",
       "            4.5703e-01,  1.1016e+00],\n",
       "          ...,\n",
       "          [-1.4221e-02,  5.3906e-01, -4.2188e-01,  ...,  7.2266e-01,\n",
       "           -1.0859e+00,  3.6719e-01],\n",
       "          [-3.0859e-01, -4.4727e-01, -1.2578e+00,  ...,  5.7617e-02,\n",
       "           -6.2109e-01,  6.9141e-01],\n",
       "          [-2.3828e-01, -4.9805e-01, -2.7344e-01,  ...,  4.2383e-01,\n",
       "            3.4375e-01,  7.2266e-01]],\n",
       "\n",
       "         [[-3.5858e-04,  1.6235e-02, -7.2937e-03,  ..., -1.2878e-02,\n",
       "            5.3711e-03,  1.0986e-02],\n",
       "          [-1.0156e+00, -2.7161e-03, -8.2812e-01,  ..., -4.1504e-02,\n",
       "            4.5117e-01,  4.1992e-01],\n",
       "          [-6.3281e-01, -8.8867e-02, -2.5586e-01,  ..., -1.0681e-04,\n",
       "           -5.2344e-01,  2.4121e-01],\n",
       "          ...,\n",
       "          [-7.7734e-01,  9.1406e-01, -1.5564e-02,  ..., -4.7070e-01,\n",
       "           -1.9434e-01, -5.2344e-01],\n",
       "          [-2.6172e-01,  5.1562e-01,  2.8125e-01,  ..., -4.1406e-01,\n",
       "            6.1719e-01,  3.2031e-01],\n",
       "          [ 5.4297e-01,  6.3281e-01,  2.1484e-01,  ..., -1.5234e+00,\n",
       "           -8.0566e-02, -1.1133e-01]],\n",
       "\n",
       "         [[ 1.3855e-02, -5.0659e-03,  1.2512e-02,  ..., -7.8125e-03,\n",
       "           -3.2471e-02, -1.4160e-02],\n",
       "          [ 2.0605e-01,  3.1445e-01, -1.5137e-01,  ..., -6.1035e-02,\n",
       "            8.2031e-01,  3.8477e-01],\n",
       "          [-1.7090e-01,  8.7402e-02,  1.7773e-01,  ...,  3.7891e-01,\n",
       "            1.9219e+00,  4.4922e-01],\n",
       "          ...,\n",
       "          [ 1.1406e+00,  9.1016e-01,  3.3008e-01,  ...,  1.5918e-01,\n",
       "           -1.0703e+00, -6.6016e-01],\n",
       "          [-1.9922e-01,  3.7109e-01, -4.1797e-01,  ...,  1.0156e+00,\n",
       "           -1.2969e+00, -1.0156e+00],\n",
       "          [ 1.1133e-01,  1.6875e+00,  6.9922e-01,  ..., -6.3281e-01,\n",
       "            2.5156e+00, -1.0078e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>)), (tensor([[[[-1.0864e-02, -3.4180e-02,  4.5410e-02,  ..., -8.7109e-01,\n",
       "           -9.8633e-02,  3.9673e-03],\n",
       "          [ 1.4219e+00, -5.0781e-01, -4.9609e-01,  ...,  1.5625e+00,\n",
       "            9.3750e-01,  1.8750e-01],\n",
       "          [ 6.4844e-01, -2.7148e-01, -1.7734e+00,  ...,  1.9531e+00,\n",
       "            1.0469e+00,  4.5703e-01],\n",
       "          ...,\n",
       "          [-1.1328e-01, -3.7598e-02,  0.0000e+00,  ...,  1.2656e+00,\n",
       "            2.9419e-02, -3.8477e-01],\n",
       "          [-6.9141e-01, -6.1719e-01,  1.7969e-01,  ...,  1.4688e+00,\n",
       "           -7.3047e-01,  4.3945e-01],\n",
       "          [ 5.8594e-01,  1.8164e-01,  2.4219e-01,  ...,  1.7109e+00,\n",
       "           -1.1084e-01, -1.5430e-01]],\n",
       "\n",
       "         [[ 6.0303e-02, -1.0205e-01,  1.5747e-02,  ...,  5.9082e-02,\n",
       "           -3.1445e-01,  5.1953e-01],\n",
       "          [-9.1016e-01, -1.0352e-01,  1.3867e-01,  ..., -2.1094e+00,\n",
       "           -1.1250e+00,  2.2266e-01],\n",
       "          [-1.1094e+00,  4.1016e-01, -1.2207e-01,  ..., -2.7344e+00,\n",
       "           -7.8125e-01,  3.9648e-01],\n",
       "          ...,\n",
       "          [ 1.6953e+00,  5.5469e-01, -5.0781e-01,  ...,  4.2188e-01,\n",
       "           -9.5312e-01, -1.9766e+00],\n",
       "          [ 7.4707e-02,  1.3594e+00,  1.6016e-01,  ..., -2.1875e+00,\n",
       "            4.8633e-01, -2.2344e+00],\n",
       "          [-6.0938e-01, -2.1484e-02, -7.6172e-02,  ..., -1.4844e+00,\n",
       "           -1.9297e+00, -9.2969e-01]],\n",
       "\n",
       "         [[ 3.8818e-02,  1.8555e-02,  3.3691e-02,  ..., -8.4229e-03,\n",
       "            3.1055e-01, -1.6562e+00],\n",
       "          [-5.7422e-01, -1.1475e-01,  1.9989e-03,  ..., -1.3184e-01,\n",
       "           -1.2188e+00,  1.0703e+00],\n",
       "          [-3.6914e-01,  5.9082e-02,  4.1260e-02,  ..., -5.5469e-01,\n",
       "           -1.1172e+00,  1.1641e+00],\n",
       "          ...,\n",
       "          [ 5.6641e-01, -4.6875e-01,  6.2500e-01,  ..., -2.1406e+00,\n",
       "           -9.2773e-02,  1.2158e-01],\n",
       "          [-5.2344e-01, -4.6289e-01,  5.4297e-01,  ..., -1.2188e+00,\n",
       "           -3.1641e-01,  7.6172e-01],\n",
       "          [-8.5938e-01,  1.1816e-01, -2.6367e-01,  ..., -2.0156e+00,\n",
       "           -4.1602e-01,  5.7422e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 5.7129e-02, -3.1738e-02,  1.1230e-02,  ...,  8.2422e-01,\n",
       "           -7.3828e-01,  1.4844e+00],\n",
       "          [-5.5469e-01, -4.4141e-01,  2.9883e-01,  ...,  3.6133e-01,\n",
       "            2.2344e+00, -4.7656e-01],\n",
       "          [-4.8047e-01, -4.2578e-01, -1.7578e-01,  ..., -3.8867e-01,\n",
       "            2.0156e+00, -8.3594e-01],\n",
       "          ...,\n",
       "          [ 2.8125e-01, -3.9648e-01, -1.9727e-01,  ..., -3.9258e-01,\n",
       "            1.8672e+00, -1.0234e+00],\n",
       "          [ 4.6875e-02,  1.4609e+00,  9.9219e-01,  ...,  3.6621e-02,\n",
       "            1.5938e+00,  4.9805e-01],\n",
       "          [-8.9844e-01,  4.9414e-01,  6.0547e-01,  ..., -2.9883e-01,\n",
       "            1.7500e+00, -1.3750e+00]],\n",
       "\n",
       "         [[ 4.7363e-02,  3.4912e-02, -2.1606e-02,  ..., -4.3945e-02,\n",
       "           -5.8203e-01,  1.9141e-01],\n",
       "          [-2.4902e-01,  2.4219e-01, -2.0898e-01,  ...,  1.9336e-01,\n",
       "            4.6484e-01,  9.9609e-01],\n",
       "          [-6.6406e-01,  6.2988e-02,  7.0312e-02,  ...,  4.5312e-01,\n",
       "           -2.7710e-02,  1.0703e+00],\n",
       "          ...,\n",
       "          [ 5.7422e-01,  6.0938e-01, -5.1562e-01,  ...,  9.3750e-01,\n",
       "           -9.2188e-01, -7.3828e-01],\n",
       "          [-1.4941e-01, -1.7090e-01, -4.4141e-01,  ..., -5.5908e-02,\n",
       "           -4.6094e-01,  1.1953e+00],\n",
       "          [ 6.5625e-01, -9.7656e-04, -5.6641e-01,  ...,  1.5312e+00,\n",
       "            6.4453e-01,  1.0205e-01]],\n",
       "\n",
       "         [[ 4.1748e-02,  1.9531e-02, -3.2471e-02,  ..., -4.6484e-01,\n",
       "            7.1875e-01, -9.2163e-03],\n",
       "          [-3.1055e-01,  6.0547e-01,  3.1055e-01,  ..., -6.9531e-01,\n",
       "            1.7656e+00, -8.2031e-01],\n",
       "          [-1.3594e+00,  1.9043e-01,  1.5625e+00,  ...,  5.0391e-01,\n",
       "            2.3125e+00,  4.0234e-01],\n",
       "          ...,\n",
       "          [ 8.8281e-01, -6.3281e-01,  3.5547e-01,  ..., -9.6094e-01,\n",
       "           -1.0254e-01,  4.8828e-01],\n",
       "          [ 4.0625e-01,  1.1094e+00, -1.4941e-01,  ...,  4.9023e-01,\n",
       "           -2.0508e-01, -4.9023e-01],\n",
       "          [-4.5703e-01,  9.2578e-01, -3.4766e-01,  ...,  6.5918e-02,\n",
       "           -1.8359e-01,  6.3672e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<AddBackward0>), tensor([[[[-3.6377e-02, -1.4648e-02,  2.3804e-02,  ...,  5.6458e-03,\n",
       "            2.7954e-02, -3.8086e-02],\n",
       "          [ 8.2031e-02, -1.6022e-03, -2.1875e-01,  ...,  1.3477e-01,\n",
       "           -8.0859e-01, -8.2520e-02],\n",
       "          [-1.3965e-01,  5.3125e-01, -3.2422e-01,  ..., -3.7354e-02,\n",
       "           -4.3555e-01,  9.0332e-02],\n",
       "          ...,\n",
       "          [-3.7842e-02, -6.4062e-01,  9.5703e-01,  ...,  3.7109e-01,\n",
       "           -9.7656e-02,  7.3828e-01],\n",
       "          [ 4.2969e-01,  2.9688e-01,  5.6250e-01,  ...,  9.4922e-01,\n",
       "           -9.4922e-01,  1.5625e-01],\n",
       "          [-4.5508e-01, -2.8320e-01,  2.1729e-02,  ..., -4.2969e-01,\n",
       "           -6.9922e-01,  1.8066e-01]],\n",
       "\n",
       "         [[-3.4943e-03,  2.0264e-02,  3.8330e-02,  ...,  2.1240e-02,\n",
       "            7.5378e-03, -5.8594e-03],\n",
       "          [-8.8281e-01, -1.6309e-01,  8.0859e-01,  ..., -2.8320e-01,\n",
       "            2.2754e-01,  7.2656e-01],\n",
       "          [-1.0078e+00, -6.2109e-01,  1.0703e+00,  ..., -3.7305e-01,\n",
       "            7.3242e-02,  7.6953e-01],\n",
       "          ...,\n",
       "          [-3.3594e-01, -8.3984e-01,  9.2285e-02,  ..., -6.5625e-01,\n",
       "            2.2363e-01, -3.6523e-01],\n",
       "          [-1.5391e+00, -3.5938e-01, -5.9766e-01,  ..., -7.3438e-01,\n",
       "            6.6406e-01, -4.4727e-01],\n",
       "          [ 1.5723e-01, -5.7812e-01,  2.3730e-01,  ..., -3.9062e-01,\n",
       "            1.2598e-01,  8.2031e-01]],\n",
       "\n",
       "         [[-3.5156e-02,  4.4434e-02, -7.2021e-03,  ..., -5.2002e-02,\n",
       "           -7.7820e-04, -8.1787e-03],\n",
       "          [-3.0469e-01, -4.5703e-01, -3.0078e-01,  ..., -3.1445e-01,\n",
       "            1.2695e-01, -3.4375e-01],\n",
       "          [-8.2031e-02, -1.4062e-01, -3.9844e-01,  ..., -2.8711e-01,\n",
       "           -5.5420e-02, -5.2734e-01],\n",
       "          ...,\n",
       "          [ 5.6250e-01, -5.3711e-02, -1.5039e-01,  ..., -4.1260e-02,\n",
       "           -6.4941e-02, -1.4465e-02],\n",
       "          [-2.2949e-01, -6.7969e-01,  1.8164e-01,  ..., -7.8125e-02,\n",
       "           -5.1172e-01,  8.4961e-02],\n",
       "          [ 1.9629e-01, -4.3164e-01, -4.1211e-01,  ..., -1.5039e-01,\n",
       "           -2.8516e-01, -5.4443e-02]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-5.4321e-03, -1.5747e-02,  2.5024e-03,  ..., -5.3406e-03,\n",
       "            2.4048e-02, -1.2207e-02],\n",
       "          [-1.3086e-01, -2.7734e-01, -4.1406e-01,  ...,  5.9326e-02,\n",
       "           -7.6562e-01, -8.9844e-01],\n",
       "          [-2.5000e-01,  7.5781e-01, -4.5471e-03,  ...,  3.4961e-01,\n",
       "           -7.5000e-01, -7.3438e-01],\n",
       "          ...,\n",
       "          [ 3.4961e-01,  1.7285e-01, -9.3359e-01,  ...,  4.4727e-01,\n",
       "            5.6250e-01,  8.7500e-01],\n",
       "          [-6.2988e-02,  4.7852e-01, -1.0234e+00,  ...,  8.3984e-01,\n",
       "            2.5586e-01,  1.1670e-01],\n",
       "          [-1.3086e-01, -2.7710e-02, -1.5234e+00,  ..., -2.2656e-01,\n",
       "            1.4844e-01, -4.1211e-01]],\n",
       "\n",
       "         [[ 6.8359e-03, -2.1851e-02, -1.6602e-02,  ..., -2.7954e-02,\n",
       "           -9.0332e-03, -2.6367e-02],\n",
       "          [-6.2891e-01, -1.6016e-01,  4.6387e-02,  ...,  1.3867e-01,\n",
       "           -6.8848e-02,  1.4551e-01],\n",
       "          [-3.3008e-01, -4.6484e-01, -6.9275e-03,  ..., -5.3906e-01,\n",
       "           -1.9727e-01, -2.9785e-02],\n",
       "          ...,\n",
       "          [ 1.9775e-02,  4.3555e-01,  5.0391e-01,  ..., -6.6016e-01,\n",
       "            6.4453e-02, -7.8516e-01],\n",
       "          [ 5.5176e-02, -4.2578e-01,  6.3672e-01,  ...,  7.7344e-01,\n",
       "            6.6406e-01, -6.8359e-01],\n",
       "          [ 1.6699e-01, -6.7578e-01,  2.4902e-01,  ...,  6.7871e-02,\n",
       "            2.0703e-01,  9.5703e-02]],\n",
       "\n",
       "         [[ 3.0365e-03, -1.9897e-02,  3.0396e-02,  ..., -5.7617e-02,\n",
       "            2.1240e-02, -5.5664e-02],\n",
       "          [ 1.1084e-01,  1.3672e-01,  1.4258e-01,  ...,  4.1797e-01,\n",
       "            8.9355e-02,  2.4316e-01],\n",
       "          [-6.6797e-01,  1.4453e-01,  4.4189e-02,  ...,  8.9453e-01,\n",
       "            2.6758e-01, -2.5000e-01],\n",
       "          ...,\n",
       "          [ 2.0752e-02,  5.2734e-01, -2.3315e-02,  ..., -2.6953e-01,\n",
       "            1.7188e-01,  2.2168e-01],\n",
       "          [ 5.1172e-01,  7.6562e-01, -1.4160e-01,  ..., -1.1377e-01,\n",
       "            3.8281e-01, -1.0449e-01],\n",
       "          [-1.7969e-01,  4.5703e-01, -2.1387e-01,  ..., -2.2852e-01,\n",
       "            8.5938e-01,  3.6523e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<TransposeBackward0>))), hidden_states=None, attentions=None)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(**input_ids)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "LlamaForCausalLM(\n",
    "  (model): LlamaModel(\n",
    "    (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n",
    "    (layers): ModuleList(\n",
    "      (0-31): 32 x LlamaDecoderLayer(\n",
    "        (self_attn): LlamaSdpaAttention(\n",
    "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
    "          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
    "          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
    "          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
    "          (rotary_emb): LlamaRotaryEmbedding()\n",
    "        )\n",
    "        (mlp): LlamaMLP(\n",
    "          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
    "          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
    "          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
    "          (act_fn): SiLU()\n",
    "        )\n",
    "        (input_layernorm): LlamaRMSNorm()\n",
    "        (post_attention_layernorm): LlamaRMSNorm()\n",
    "      )\n",
    "    )\n",
    "    (norm): LlamaRMSNorm()\n",
    "  )\n",
    "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List every parameter tensor by its qualified name together with its shape.\n",
    "for param_name, param in model.named_parameters():\n",
    "    print(param_name, param.shape)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "model.embed_tokens.weight torch.Size([32000, 4096])\n",
    "model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096])\n",
    "model.layers.0.self_attn.k_proj.weight torch.Size([4096, 4096])\n",
    "model.layers.0.self_attn.v_proj.weight torch.Size([4096, 4096])\n",
    "model.layers.0.self_attn.o_proj.weight torch.Size([4096, 4096])\n",
    "model.layers.0.mlp.gate_proj.weight torch.Size([11008, 4096])\n",
    "model.layers.0.mlp.up_proj.weight torch.Size([11008, 4096])\n",
    "model.layers.0.mlp.down_proj.weight torch.Size([4096, 11008])\n",
    "model.layers.0.input_layernorm.weight torch.Size([4096])\n",
    "model.layers.0.post_attention_layernorm.weight torch.Size([4096])\n",
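    "... (model.layers.1 through model.layers.30 omitted; same parameter names and shapes as model.layers.0) ...\n",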
    "model.layers.31.self_attn.q_proj.weight torch.Size([4096, 4096])\n",
    "model.layers.31.self_attn.k_proj.weight torch.Size([4096, 4096])\n",
    "model.layers.31.self_attn.v_proj.weight torch.Size([4096, 4096])\n",
    "model.layers.31.self_attn.o_proj.weight torch.Size([4096, 4096])\n",
    "model.layers.31.mlp.gate_proj.weight torch.Size([11008, 4096])\n",
    "model.layers.31.mlp.up_proj.weight torch.Size([11008, 4096])\n",
    "model.layers.31.mlp.down_proj.weight torch.Size([4096, 11008])\n",
    "model.layers.31.input_layernorm.weight torch.Size([4096])\n",
    "model.layers.31.post_attention_layernorm.weight torch.Size([4096])\n",
    "model.norm.weight torch.Size([4096])\n",
    "lm_head.weight torch.Size([32000, 4096])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ke2torch23cu121",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
